diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 51f49da15..50fb922fe 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -14,6 +14,7 @@ from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.agents import AgentConfig from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.eval_tasks import * # noqa: F403 @json_schema_type @@ -35,36 +36,57 @@ EvalCandidate = Annotated[ ] +@json_schema_type +class BenchmarkEvalTaskConfig(BaseModel): + type: Literal["benchmark"] = "benchmark" + eval_candidate: EvalCandidate + + +@json_schema_type +class AppEvalTaskConfig(BaseModel): + type: Literal["app"] = "app" + eval_candidate: EvalCandidate + scoring_params: Dict[str, ScoringFnParams] = Field( + description="Map between scoring function id and parameters for each scoring function you want to run", + default_factory=dict, + ) + # we could optinally add any specific dataset config here + + +EvalTaskConfig = Annotated[ + Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type") +] + + @json_schema_type class EvaluateResponse(BaseModel): generations: List[Dict[str, Any]] - # each key in the dict is a scoring function name scores: Dict[str, ScoringResult] class Eval(Protocol): - @webmethod(route="/eval/evaluate_batch", method="POST") - async def evaluate_batch( + @webmethod(route="/eval/run_eval", method="POST") + async def run_eval( self, - dataset_id: str, - candidate: EvalCandidate, - scoring_functions: List[str], + task_id: str, + task_config: EvalTaskConfig, ) -> Job: ... - @webmethod(route="/eval/evaluate", method="POST") - async def evaluate( + @webmethod(route="/eval/evaluate_rows", method="POST") + async def evaluate_rows( self, + task_id: str, input_rows: List[Dict[str, Any]], - candidate: EvalCandidate, scoring_functions: List[str], + task_config: EvalTaskConfig, ) -> EvaluateResponse: ... @webmethod(route="/eval/job/status", method="GET") - async def job_status(self, job_id: str) -> Optional[JobStatus]: ... + async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ... @webmethod(route="/eval/job/cancel", method="POST") - async def job_cancel(self, job_id: str) -> None: ... + async def job_cancel(self, task_id: str, job_id: str) -> None: ... @webmethod(route="/eval/job/result", method="GET") - async def job_result(self, job_id: str) -> EvaluateResponse: ... + async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ... diff --git a/llama_stack/apis/eval_tasks/__init__.py b/llama_stack/apis/eval_tasks/__init__.py new file mode 100644 index 000000000..7ca216706 --- /dev/null +++ b/llama_stack/apis/eval_tasks/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .eval_tasks import * # noqa: F401 F403 diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py new file mode 100644 index 000000000..0007066aa --- /dev/null +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable + +from llama_models.schema_utils import json_schema_type, webmethod + +from pydantic import BaseModel, Field + + +@json_schema_type +class EvalTaskDef(BaseModel): + identifier: str + dataset_id: str + scoring_functions: List[str] + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Metadata for this evaluation task", + ) + + +@json_schema_type +class EvalTaskDefWithProvider(EvalTaskDef): + type: Literal["eval_task"] = "eval_task" + provider_id: str = Field( + description="ID of the provider which serves this dataset", + ) + + +@runtime_checkable +class EvalTasks(Protocol): + @webmethod(route="/eval_tasks/list", method="GET") + async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ... + + @webmethod(route="/eval_tasks/get", method="GET") + async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ... + + @webmethod(route="/eval_tasks/register", method="POST") + async def register_eval_task( + self, eval_task_def: EvalTaskDefWithProvider + ) -> None: ... diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 1fd523dcb..c2bfdcd23 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -48,11 +48,13 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @webmethod(route="/scoring/score") async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: ... diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index d0a9cc597..140376242 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,34 +4,66 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable +from enum import Enum +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field +from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType -@json_schema_type -class Parameter(BaseModel): - name: str - type: ParamType - description: Optional[str] = None - - # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up? +@json_schema_type +class ScoringConfigType(Enum): + llm_as_judge = "llm_as_judge" + regex_parser = "regex_parser" -class LLMAsJudgeContext(BaseModel): +@json_schema_type +class LLMAsJudgeScoringFnParams(BaseModel): + type: Literal[ScoringConfigType.llm_as_judge.value] = ( + ScoringConfigType.llm_as_judge.value + ) judge_model: str prompt_template: Optional[str] = None - judge_score_regex: Optional[List[str]] = Field( - description="Regex to extract the score from the judge response", - default=None, + judge_score_regexes: Optional[List[str]] = Field( + description="Regexes to extract the answer from generated response", + default_factory=list, ) +@json_schema_type +class RegexParserScoringFnParams(BaseModel): + type: Literal[ScoringConfigType.regex_parser.value] = ( + ScoringConfigType.regex_parser.value + ) + parsing_regexes: Optional[List[str]] = Field( + description="Regex to extract the answer from generated response", + default_factory=list, + ) + + +ScoringFnParams = Annotated[ + Union[ + LLMAsJudgeScoringFnParams, + RegexParserScoringFnParams, + ], + Field(discriminator="type"), +] + + @json_schema_type class ScoringFnDef(BaseModel): identifier: str @@ -40,14 +72,13 @@ class ScoringFnDef(BaseModel): default_factory=dict, description="Any additional metadata for this definition", ) - parameters: List[Parameter] = Field( - description="List of parameters for the deterministic function", - default_factory=list, - ) return_type: ParamType = Field( description="The return type of the deterministic function", ) - context: Optional[LLMAsJudgeContext] = None + params: Optional[ScoringFnParams] = Field( + description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval", + default=None, + ) # We can optionally add information here to support packaging of code, etc. diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 2149162a6..3fc3b2d5d 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -43,6 +43,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: routing_table_api=Api.scoring_functions, router_api=Api.scoring, ), + AutoRoutedApiInfo( + routing_table_api=Api.eval_tasks, + router_api=Api.eval, + ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 9b8e41561..aac7ae5b6 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -17,6 +17,7 @@ from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval +from llama_stack.apis.eval_tasks import EvalTasks from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.memory import Memory @@ -48,6 +49,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.scoring: Scoring, Api.scoring_functions: ScoringFunctions, Api.eval: Eval, + Api.eval_tasks: EvalTasks, } @@ -58,6 +60,7 @@ def additional_protocols_map() -> Dict[Api, Any]: Api.safety: (ShieldsProtocolPrivate, Shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets), Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions), + Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks), } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index b3ebd1368..57e81ac30 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -12,6 +12,7 @@ from llama_stack.distribution.store import DistributionRegistry from .routing_tables import ( DatasetsRoutingTable, + EvalTasksRoutingTable, MemoryBanksRoutingTable, ModelsRoutingTable, ScoringFunctionsRoutingTable, @@ -31,6 +32,7 @@ async def get_routing_table_impl( "shields": ShieldsRoutingTable, "datasets": DatasetsRoutingTable, "scoring_functions": ScoringFunctionsRoutingTable, + "eval_tasks": EvalTasksRoutingTable, } if api.value not in api_to_tables: @@ -44,6 +46,7 @@ async def get_routing_table_impl( async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: from .routers import ( DatasetIORouter, + EvalRouter, InferenceRouter, MemoryRouter, SafetyRouter, @@ -56,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "safety": SafetyRouter, "datasetio": DatasetIORouter, "scoring": ScoringRouter, + "eval": EvalRouter, } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 760dbaf2f..8edf950b2 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -14,6 +14,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.eval import * # noqa: F403 class MemoryRouter(Memory): @@ -211,16 +212,16 @@ class ScoringRouter(Scoring): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: res = {} - for fn_identifier in scoring_functions: + for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl( fn_identifier ).score_batch( dataset_id=dataset_id, - scoring_functions=[fn_identifier], + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) res.update(score_response.results) @@ -232,17 +233,87 @@ class ScoringRouter(Scoring): ) async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: res = {} # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions: + for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl( fn_identifier ).score( input_rows=input_rows, - scoring_functions=[fn_identifier], + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) res.update(score_response.results) return ScoreResponse(results=res) + + +class EvalRouter(Eval): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def run_eval( + self, + task_id: str, + task_config: AppEvalTaskConfig, + ) -> Job: + return await self.routing_table.get_provider_impl(task_id).run_eval( + task_id=task_id, + task_config=task_config, + ) + + @webmethod(route="/eval/evaluate_rows", method="POST") + async def evaluate_rows( + self, + task_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + task_config: EvalTaskConfig, + ) -> EvaluateResponse: + return await self.routing_table.get_provider_impl(task_id).evaluate_rows( + task_id=task_id, + input_rows=input_rows, + scoring_functions=scoring_functions, + task_config=task_config, + ) + + async def job_status( + self, + task_id: str, + job_id: str, + ) -> Optional[JobStatus]: + return await self.routing_table.get_provider_impl(task_id).job_status( + task_id, job_id + ) + + async def job_cancel( + self, + task_id: str, + job_id: str, + ) -> None: + await self.routing_table.get_provider_impl(task_id).job_cancel( + task_id, + job_id, + ) + + async def job_result( + self, + task_id: str, + job_id: str, + ) -> EvaluateResponse: + return await self.routing_table.get_provider_impl(task_id).job_result( + task_id, + job_id, + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index bcf125bec..a676b5fef 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -12,6 +12,8 @@ from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.eval_tasks import * # noqa: F403 + from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 @@ -40,6 +42,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_dataset(obj) elif api == Api.scoring: await p.register_scoring_function(obj) + elif api == Api.eval: + await p.register_eval_task(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -103,6 +107,11 @@ class CommonRoutingTableImpl(RoutingTable): scoring_functions = await p.list_scoring_functions() await add_objects(scoring_functions, pid, ScoringFnDefWithProvider) + elif api == Api.eval: + p.eval_task_store = self + eval_tasks = await p.list_eval_tasks() + await add_objects(eval_tasks, pid, EvalTaskDefWithProvider) + async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() @@ -121,6 +130,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("DatasetIO", "dataset") elif isinstance(self, ScoringFunctionsRoutingTable): return ("Scoring", "scoring_function") + elif isinstance(self, EvalTasksRoutingTable): + return ("Eval", "eval_task") else: raise ValueError("Unknown routing table type") @@ -246,9 +257,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): await self.register_object(dataset_def) -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - return await self.get_all_with_type("scoring_function") + return await self.get_all_with_type("scoring_fn") async def get_scoring_function( self, name: str @@ -259,3 +270,14 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): self, function_def: ScoringFnDefWithProvider ) -> None: await self.register_object(function_def) + + +class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): + async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]: + return await self.get_all_with_type("eval_task") + + async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: + return await self.get_object_by_identifier(name) + + async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None: + await self.register_object(eval_task_def) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 59c5a38fa..0f82ca592 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -12,6 +12,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef +from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.memory_banks import MemoryBankDef from llama_stack.apis.models import ModelDef from llama_stack.apis.scoring_functions import ScoringFnDef @@ -35,6 +36,7 @@ class Api(Enum): memory_banks = "memory_banks" datasets = "datasets" scoring_functions = "scoring_functions" + eval_tasks = "eval_tasks" # built-in API inspect = "inspect" @@ -70,6 +72,12 @@ class ScoringFunctionsProtocolPrivate(Protocol): async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ... +class EvalTasksProtocolPrivate(Protocol): + async def list_eval_tasks(self) -> List[EvalTaskDef]: ... + + async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ... + + @json_schema_type class ProviderSpec(BaseModel): api: Api diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py index 3aec6170f..4a61c9d93 100644 --- a/llama_stack/providers/inline/meta_reference/eval/eval.py +++ b/llama_stack/providers/inline/meta_reference/eval/eval.py @@ -6,13 +6,15 @@ from enum import Enum from llama_models.llama3.api.datatypes import * # noqa: F403 +from .....apis.common.job_types import Job +from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.common.job_types import Job from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus +from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.inference import Inference from llama_stack.apis.scoring import Scoring +from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from .config import MetaReferenceEvalConfig @@ -25,7 +27,7 @@ class ColumnName(Enum): generated_answer = "generated_answer" -class MetaReferenceEvalImpl(Eval): +class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): def __init__( self, config: MetaReferenceEvalConfig, @@ -43,10 +45,18 @@ class MetaReferenceEvalImpl(Eval): # TODO: assume sync job, will need jobs API for async scheduling self.jobs = {} + self.eval_tasks = {} + async def initialize(self) -> None: ... async def shutdown(self) -> None: ... + async def register_eval_task(self, task_def: EvalTaskDef) -> None: + self.eval_tasks[task_def.identifier] = task_def + + async def list_eval_tasks(self) -> List[EvalTaskDef]: + return list(self.eval_tasks.values()) + async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None: dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: @@ -70,21 +80,26 @@ class MetaReferenceEvalImpl(Eval): f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}" ) - async def evaluate_batch( + async def run_eval( self, - dataset_id: str, - candidate: EvalCandidate, - scoring_functions: List[str], + task_id: str, + task_config: EvalTaskConfig, ) -> Job: + task_def = self.eval_tasks[task_id] + dataset_id = task_def.dataset_id + candidate = task_config.eval_candidate + scoring_functions = task_def.scoring_functions + await self.validate_eval_input_dataset_schema(dataset_id=dataset_id) all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, ) - res = await self.evaluate( + res = await self.evaluate_rows( + task_id=task_id, input_rows=all_rows.rows, - candidate=candidate, scoring_functions=scoring_functions, + task_config=task_config, ) # TODO: currently needs to wait for generation before returning @@ -93,12 +108,14 @@ class MetaReferenceEvalImpl(Eval): self.jobs[job_id] = res return Job(job_id=job_id) - async def evaluate( + async def evaluate_rows( self, + task_id: str, input_rows: List[Dict[str, Any]], - candidate: EvalCandidate, scoring_functions: List[str], + task_config: EvalTaskConfig, ) -> EvaluateResponse: + candidate = task_config.eval_candidate if candidate.type == "agent": raise NotImplementedError( "Evaluation with generation has not been implemented for agents" @@ -122,7 +139,10 @@ class MetaReferenceEvalImpl(Eval): } ) elif ColumnName.chat_completion_input.value in x: - input_messages = eval(str(x[ColumnName.chat_completion_input.value])) + chat_completion_input_str = str( + x[ColumnName.chat_completion_input.value] + ) + input_messages = eval(chat_completion_input_str) input_messages = [UserMessage(**x) for x in input_messages] messages = [] if candidate.system_message: @@ -147,23 +167,33 @@ class MetaReferenceEvalImpl(Eval): for input_r, generated_r in zip(input_rows, generations) ] + if task_config.type == "app" and task_config.scoring_params is not None: + scoring_functions_dict = { + scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None) + for scoring_fn_id in scoring_functions + } + else: + scoring_functions_dict = { + scoring_fn_id: None for scoring_fn_id in scoring_functions + } + score_response = await self.scoring_api.score( - input_rows=score_input_rows, scoring_functions=scoring_functions + input_rows=score_input_rows, scoring_functions=scoring_functions_dict ) return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, job_id: str) -> Optional[JobStatus]: + async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: if job_id in self.jobs: return JobStatus.completed return None - async def job_cancel(self, job_id: str) -> None: + async def job_cancel(self, task_id: str, job_id: str) -> None: raise NotImplementedError("Job cancel is not implemented yet") - async def job_result(self, job_id: str) -> EvaluateResponse: - status = await self.job_status(job_id) + async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: + status = await self.job_status(task_id, job_id) if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}") diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py index 709b2f0c6..c4add966d 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py @@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): return scoring_fn_defs_list async def register_scoring_function(self, function_def: ScoringFnDef) -> None: - self.llm_as_judge_fn.register_scoring_fn_def(function_def) - self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn + raise NotImplementedError("Register scoring function not implemented yet") async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) @@ -97,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) @@ -106,7 +105,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): rows_in_page=-1, ) res = await self.score( - input_rows=all_rows.rows, scoring_functions=scoring_functions + input_rows=all_rows.rows, + scoring_functions=scoring_functions, ) if save_results_dataset: # TODO: persist and register dataset on to server for reading @@ -118,14 +118,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: res = {} - for scoring_fn_id in scoring_functions: + for scoring_fn_id in scoring_functions.keys(): if scoring_fn_id not in self.scoring_fn_id_impls: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] - score_results = await scoring_fn.score(input_rows, scoring_fn_id) + scoring_fn_params = scoring_functions.get(scoring_fn_id, None) + score_results = await scoring_fn.score( + input_rows, scoring_fn_id, scoring_fn_params + ) agg_results = await scoring_fn.aggregate(score_results) res[scoring_fn_id] = ScoringResult( score_rows=score_results, diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py index cbd875be6..532686ebd 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py @@ -36,7 +36,10 @@ class BaseScoringFn(ABC): @abstractmethod async def score_row( - self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: raise NotImplementedError() @@ -50,8 +53,9 @@ class BaseScoringFn(ABC): self, input_rows: List[Dict[str, Any]], scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> List[ScoringResultRow]: return [ - await self.score_row(input_row, scoring_fn_identifier) + await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows ] diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index 2a0cd0578..07405d56c 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn): self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = "equality", + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: assert "expected_answer" in input_row, "Expected answer not found in input row." assert ( diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py index 20a67edc7..cfef52160 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py @@ -28,9 +28,13 @@ llm_as_judge_8b_correctness = ScoringFnDef( description="Llm As Judge Scoring Function", parameters=[], return_type=NumberType(), - context=LLMAsJudgeContext( + params=LLMAsJudgeScoringFnParams( prompt_template=JUDGE_PROMPT, judge_model="Llama3.1-8B-Instruct", - judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"], + judge_score_regexes=[ + r"Total rating: (\d+)", + r"rating: (\d+)", + r"Rating: (\d+)", + ], ), ) diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py index 84dd28fd7..f98f7fb5e 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py @@ -36,31 +36,37 @@ class LlmAsJudgeScoringFn(BaseScoringFn): self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: assert ( scoring_fn_identifier is not None ), "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] - assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}." + + # override params if scoring_params is provided + if scoring_params is not None: + fn_def.params = scoring_params + + assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}." assert ( - fn_def.context.prompt_template is not None + fn_def.params.prompt_template is not None ), "LLM Judge prompt_template not found." assert ( - fn_def.context.judge_score_regex is not None - ), "LLM Judge judge_score_regex not found." + fn_def.params.judge_score_regexes is not None + ), "LLM Judge judge_score_regexes not found." input_query = input_row["input_query"] expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] - judge_input_msg = fn_def.context.prompt_template.format( + judge_input_msg = fn_def.params.prompt_template.format( input_query=input_query, expected_answer=expected_answer, generated_answer=generated_answer, ) judge_response = await self.inference_api.chat_completion( - model=fn_def.context.judge_model, + model=fn_def.params.judge_model, messages=[ { "role": "user", @@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn): ], ) content = judge_response.completion_message.content - rating_regexs = fn_def.context.judge_score_regex + rating_regexes = fn_def.params.judge_score_regexes judge_rating = None - for regex in rating_regexs: + for regex in rating_regexes: match = re.search(regex, content) if match: judge_rating = int(match.group(1)) diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index f42964c1f..289c63dd7 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn): self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = "subset_of", + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 2278e1a6c..3bec2d11d 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -153,4 +153,7 @@ pytest_plugins = [ "llama_stack.providers.tests.safety.fixtures", "llama_stack.providers.tests.memory.fixtures", "llama_stack.providers.tests.agents.fixtures", + "llama_stack.providers.tests.datasetio.fixtures", + "llama_stack.providers.tests.scoring.fixtures", + "llama_stack.providers.tests.eval.fixtures", ] diff --git a/llama_stack/providers/tests/datasetio/conftest.py b/llama_stack/providers/tests/datasetio/conftest.py new file mode 100644 index 000000000..740eddb33 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from .fixtures import DATASETIO_FIXTURES + + +def pytest_configure(config): + for fixture_name in DATASETIO_FIXTURES: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_generate_tests(metafunc): + if "datasetio_stack" in metafunc.fixturenames: + metafunc.parametrize( + "datasetio_stack", + [ + pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) + for fixture_name in DATASETIO_FIXTURES + ], + indirect=True, + ) diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py new file mode 100644 index 000000000..7d7615b55 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def datasetio_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def datasetio_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +DATASETIO_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def datasetio_stack(request): + fixture_name = request.param + fixture = request.getfixturevalue(f"datasetio_{fixture_name}") + + impls = await resolve_impls_for_test_v2( + [Api.datasetio], + {"datasetio": fixture.providers}, + fixture.provider_data, + ) + + return impls[Api.datasetio], impls[Api.datasets] diff --git a/llama_stack/providers/tests/datasetio/provider_config_example.yaml b/llama_stack/providers/tests/datasetio/provider_config_example.yaml deleted file mode 100644 index c0565a39e..000000000 --- a/llama_stack/providers/tests/datasetio/provider_config_example.yaml +++ /dev/null @@ -1,4 +0,0 @@ -providers: - - provider_id: test-meta - provider_type: meta-reference - config: {} diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 866b1e270..c02794c50 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -3,11 +3,10 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + import os import pytest -import pytest_asyncio - from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 @@ -15,35 +14,11 @@ import base64 import mimetypes from pathlib import Path -from llama_stack.providers.tests.resolver import resolve_impls_for_test - # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \ -# --tb=short --disable-warnings -# ``` - - -@pytest_asyncio.fixture(scope="session") -async def datasetio_settings(): - impls = await resolve_impls_for_test( - Api.datasetio, - ) - return { - "datasetio_impl": impls[Api.datasetio], - "datasets_impl": impls[Api.datasets], - } +# pytest llama_stack/providers/tests/datasetio/test_datasetio.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings def data_url_from_file(file_path: str) -> str: @@ -82,8 +57,7 @@ async def register_dataset( dataset = DatasetDefWithProvider( identifier=dataset_id, - provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None) - or os.environ["PROVIDER_ID"], + provider_id="", url=URL( uri=test_url, ), @@ -92,57 +66,47 @@ async def register_dataset( await datasets_impl.register_dataset(dataset) -@pytest.mark.asyncio -async def test_datasets_list(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 0 +class TestDatasetIO: + @pytest.mark.asyncio + async def test_datasets_list(self, datasetio_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, datasets_impl = datasetio_stack + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 + @pytest.mark.asyncio + async def test_register_dataset(self, datasetio_stack): + _, datasets_impl = datasetio_stack + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" -@pytest.mark.asyncio -async def test_datasets_register(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) + @pytest.mark.asyncio + async def test_get_rows_paginated(self, datasetio_stack): + datasetio_impl, datasets_impl = datasetio_stack + await register_dataset(datasets_impl) + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == "3" - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 + provider = datasetio_impl.routing_table.get_provider_impl("test_dataset") + if provider.__provider_spec__.provider_type == "remote": + pytest.skip("remote provider doesn't support get_rows_paginated") - # register same dataset with same id again will fail - await register_dataset(datasets_impl) - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - assert response[0].identifier == "test_dataset" - - -@pytest.mark.asyncio -async def test_get_rows_paginated(datasetio_settings): - datasetio_impl = datasetio_settings["datasetio_impl"] - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) - - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - - assert isinstance(response.rows, list) - assert len(response.rows) == 3 - assert response.next_page_token == "3" - - # iterate over all rows - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=2, - page_token=response.next_page_token, - ) - - assert isinstance(response.rows, list) - assert len(response.rows) == 2 - assert response.next_page_token == "5" + # iterate over all rows + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=2, + page_token=response.next_page_token, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 2 + assert response.next_page_token == "5" diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py new file mode 100644 index 000000000..064feb611 --- /dev/null +++ b/llama_stack/providers/tests/eval/conftest.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from ..conftest import get_provider_fixture_overrides + +from ..datasetio.fixtures import DATASETIO_FIXTURES +from ..inference.fixtures import INFERENCE_FIXTURES +from ..scoring.fixtures import SCORING_FIXTURES +from .fixtures import EVAL_FIXTURES + +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "eval": "meta_reference", + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "fireworks", + }, + id="meta_reference_eval_fireworks_inference", + marks=pytest.mark.meta_reference_eval_fireworks_inference, + ), + pytest.param( + { + "eval": "meta_reference", + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "together", + }, + id="meta_reference_eval_together_inference", + marks=pytest.mark.meta_reference_eval_together_inference, + ), +] + + +def pytest_configure(config): + for fixture_name in [ + "meta_reference_eval_fireworks_inference", + "meta_reference_eval_together_inference", + ]: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default="Llama3.2-3B-Instruct", + help="Specify the inference model to use for testing", + ) + + +def pytest_generate_tests(metafunc): + if "eval_stack" in metafunc.fixturenames: + available_fixtures = { + "eval": EVAL_FIXTURES, + "scoring": SCORING_FIXTURES, + "datasetio": DATASETIO_FIXTURES, + "inference": INFERENCE_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS + ) + metafunc.parametrize("eval_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py new file mode 100644 index 000000000..810239440 --- /dev/null +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def eval_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def eval_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +EVAL_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def eval_stack(request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["datasetio", "eval", "scoring", "inference"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) + + impls = await resolve_impls_for_test_v2( + [Api.eval, Api.datasetio, Api.inference, Api.scoring], + providers, + provider_data, + ) + + return impls diff --git a/llama_stack/providers/tests/eval/provider_config_example.yaml b/llama_stack/providers/tests/eval/provider_config_example.yaml deleted file mode 100644 index 38f7512f1..000000000 --- a/llama_stack/providers/tests/eval/provider_config_example.yaml +++ /dev/null @@ -1,22 +0,0 @@ -providers: - datasetio: - - provider_id: test-meta - provider_type: meta-reference - config: {} - scoring: - - provider_id: test-meta - provider_type: meta-reference - config: {} - eval: - - provider_id: test-meta - provider_type: meta-reference - config: {} - inference: - - provider_id: test-tgi - provider_type: remote::tgi - config: - url: http://127.0.0.1:5009 - - provider_id: test-tgi-2 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5010 diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 667be1bd5..a55a754c5 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -3,81 +3,124 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest -import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.eval.eval import ModelCandidate -from llama_stack.distribution.datatypes import * # noqa: F403 + +import pytest from llama_models.llama3.api import SamplingParams +from llama_stack.apis.eval.eval import ( + AppEvalTaskConfig, + EvalTaskDefWithProvider, + ModelCandidate, +) +from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset -from llama_stack.providers.tests.resolver import resolve_impls_for_test + # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/eval/test_eval.py \ -# --tb=short --disable-warnings -# ``` +# pytest llama_stack/providers/tests/eval/test_eval.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings -@pytest_asyncio.fixture(scope="session") -async def eval_settings(): - impls = await resolve_impls_for_test( - Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference] - ) - return { - "eval_impl": impls[Api.eval], - "scoring_impl": impls[Api.scoring], - "datasets_impl": impls[Api.datasets], - } +class Testeval: + @pytest.mark.asyncio + async def test_eval_tasks_list(self, eval_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + eval_tasks_impl = eval_stack[Api.eval_tasks] + response = await eval_tasks_impl.list_eval_tasks() + assert isinstance(response, list) + assert len(response) == 0 + @pytest.mark.asyncio + async def test_eval_evaluate_rows(self, eval_stack): + eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasetio], + eval_stack[Api.datasets], + ) + await register_dataset( + datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" + ) + response = await datasets_impl.list_datasets() + assert len(response) == 1 + rows = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset_for_eval", + rows_in_page=3, + ) + assert len(rows.rows) == 3 -@pytest.mark.asyncio -async def test_eval(eval_settings): - datasets_impl = eval_settings["datasets_impl"] - await register_dataset( - datasets_impl, - for_generation=True, - dataset_id="test_dataset_for_eval", - ) - - response = await datasets_impl.list_datasets() - assert len(response) == 1 - - eval_impl = eval_settings["eval_impl"] - response = await eval_impl.evaluate_batch( - dataset_id=response[0].identifier, - candidate=ModelCandidate( - model="Llama3.2-1B-Instruct", - sampling_params=SamplingParams(), - ), - scoring_functions=[ - "meta-reference::subset_of", + scoring_functions = [ "meta-reference::llm_as_judge_8b_correctness", - ], - ) - assert response.job_id == "0" - job_status = await eval_impl.job_status(response.job_id) + "meta-reference::equality", + ] + task_id = "meta-reference::app_eval" + task_def = EvalTaskDefWithProvider( + identifier=task_id, + dataset_id="test_dataset_for_eval", + scoring_functions=scoring_functions, + provider_id="meta-reference", + ) + await eval_tasks_impl.register_eval_task(task_def) - assert job_status and job_status.value == "completed" + response = await eval_impl.evaluate_rows( + task_id=task_id, + input_rows=rows.rows, + scoring_functions=scoring_functions, + task_config=AppEvalTaskConfig( + eval_candidate=ModelCandidate( + model="Llama3.2-3B-Instruct", + sampling_params=SamplingParams(), + ), + ), + ) + assert len(response.generations) == 3 + assert "meta-reference::llm_as_judge_8b_correctness" in response.scores + assert "meta-reference::equality" in response.scores - eval_response = await eval_impl.job_result(response.job_id) + @pytest.mark.asyncio + async def test_eval_run_eval(self, eval_stack): + eval_impl, eval_tasks_impl, datasets_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasets], + ) + await register_dataset( + datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" + ) - assert eval_response is not None - assert len(eval_response.generations) == 5 - assert "meta-reference::subset_of" in eval_response.scores - assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores + scoring_functions = [ + "meta-reference::llm_as_judge_8b_correctness", + "meta-reference::subset_of", + ] + + task_id = "meta-reference::app_eval-2" + task_def = EvalTaskDefWithProvider( + identifier=task_id, + dataset_id="test_dataset_for_eval", + scoring_functions=scoring_functions, + provider_id="meta-reference", + ) + await eval_tasks_impl.register_eval_task(task_def) + response = await eval_impl.run_eval( + task_id=task_id, + task_config=AppEvalTaskConfig( + eval_candidate=ModelCandidate( + model="Llama3.2-3B-Instruct", + sampling_params=SamplingParams(), + ), + ), + ) + assert response.job_id == "0" + job_status = await eval_impl.job_status(task_id, response.job_id) + assert job_status and job_status.value == "completed" + eval_response = await eval_impl.job_result(task_id, response.job_id) + + assert eval_response is not None + assert len(eval_response.generations) == 5 + assert "meta-reference::subset_of" in eval_response.scores + assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 5b047549b..1698d7584 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -64,6 +64,7 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) + print("!!!", inference_model) if "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py new file mode 100644 index 000000000..ee578f9b3 --- /dev/null +++ b/llama_stack/providers/tests/scoring/conftest.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from ..conftest import get_provider_fixture_overrides + +from ..datasetio.fixtures import DATASETIO_FIXTURES +from ..inference.fixtures import INFERENCE_FIXTURES +from .fixtures import SCORING_FIXTURES + +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "fireworks", + }, + id="meta_reference_scoring_fireworks_inference", + marks=pytest.mark.meta_reference_scoring_fireworks_inference, + ), + pytest.param( + { + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "together", + }, + id="meta_reference_scoring_together_inference", + marks=pytest.mark.meta_reference_scoring_together_inference, + ), +] + + +def pytest_configure(config): + for fixture_name in [ + "meta_reference_scoring_fireworks_inference", + "meta_reference_scoring_together_inference", + ]: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default="Llama3.2-3B-Instruct", + help="Specify the inference model to use for testing", + ) + + +def pytest_generate_tests(metafunc): + if "scoring_stack" in metafunc.fixturenames: + available_fixtures = { + "scoring": SCORING_FIXTURES, + "datasetio": DATASETIO_FIXTURES, + "inference": INFERENCE_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS + ) + metafunc.parametrize("scoring_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py new file mode 100644 index 000000000..925f98779 --- /dev/null +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def scoring_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def scoring_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +SCORING_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def scoring_stack(request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["datasetio", "scoring", "inference"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) + + impls = await resolve_impls_for_test_v2( + [Api.scoring, Api.datasetio, Api.inference], + providers, + provider_data, + ) + + return ( + impls[Api.scoring], + impls[Api.scoring_functions], + impls[Api.datasetio], + impls[Api.datasets], + ) diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml deleted file mode 100644 index 6a9c0d842..000000000 --- a/llama_stack/providers/tests/scoring/provider_config_example.yaml +++ /dev/null @@ -1,17 +0,0 @@ -providers: - datasetio: - - provider_id: test-meta - provider_type: meta-reference - config: {} - scoring: - - provider_id: test-meta - provider_type: meta-reference - config: {} - - provider_id: test-braintrust - provider_type: braintrust - config: {} - inference: - - provider_id: tgi0 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5009 diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index b9b920739..3c1b6554f 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -3,150 +3,109 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest -import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 + +import pytest + +from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset -from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ -# --tb=short --disable-warnings -# ``` +# pytest llama_stack/providers/tests/scoring/test_scoring.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings -@pytest_asyncio.fixture(scope="session") -async def scoring_settings(): - impls = await resolve_impls_for_test( - Api.scoring, deps=[Api.datasetio, Api.inference] - ) - return { - "scoring_impl": impls[Api.scoring], - "scoring_functions_impl": impls[Api.scoring_functions], - "datasets_impl": impls[Api.datasets], - } +class TestScoring: + @pytest.mark.asyncio + async def test_scoring_functions_list(self, scoring_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, scoring_functions_impl, _, _ = scoring_stack + response = await scoring_functions_impl.list_scoring_functions() + assert isinstance(response, list) + assert len(response) > 0 - -@pytest_asyncio.fixture(scope="session") -async def provider_scoring_functions(): - return { - "meta-reference": { - "meta-reference::equality", - "meta-reference::subset_of", - "meta-reference::llm_as_judge_8b_correctness", - }, - "braintrust": { - "braintrust::factuality", - "braintrust::answer-correctness", - }, - } - - -@pytest.mark.asyncio -async def test_scoring_functions_list(scoring_settings, provider_scoring_functions): - scoring_impl = scoring_settings["scoring_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - scoring_functions = await scoring_functions_impl.list_scoring_functions() - assert isinstance(scoring_functions, list) - assert len(scoring_functions) > 0 - function_ids = [f.identifier for f in scoring_functions] - # get current provider_type we're testing - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type - - for x in provider_scoring_functions[provider_type]: - assert x in function_ids - - -@pytest.mark.asyncio -async def test_scoring_functions_register(scoring_settings): - scoring_impl = scoring_settings["scoring_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - datasets_impl = scoring_settings["datasets_impl"] - - # get current provider_type we're testing - scoring_functions = await scoring_functions_impl.list_scoring_functions() - function_ids = [f.identifier for f in scoring_functions] - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type - if provider_type not in ("meta-reference"): - pytest.skip( - "Other scoring providers don't support registering scoring functions." + @pytest.mark.asyncio + async def test_scoring_score(self, scoring_stack): + scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = ( + scoring_stack ) + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert len(response) == 1 - test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: """ - # register the scoring function - await scoring_functions_impl.register_scoring_function( - ScoringFnDefWithProvider( - identifier="meta-reference::llm_as_judge_8b_random", - description="Llm As Judge Scoring Function", - parameters=[], - return_type=NumberType(), - context=LLMAsJudgeContext( - prompt_template=test_prompt, - judge_model="Llama3.1-8B-Instruct", - judge_score_regex=[r"Number: (\d+)"], - ), - provider_id="test-meta", + # scoring individual rows + rows = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, ) - ) + assert len(rows.rows) == 3 - scoring_functions = await scoring_functions_impl.list_scoring_functions() - assert isinstance(scoring_functions, list) - assert len(scoring_functions) > 0 - function_ids = [f.identifier for f in scoring_functions] - assert "meta-reference::llm_as_judge_8b_random" in function_ids + scoring_functions = { + "meta-reference::llm_as_judge_8b_correctness": None, + "meta-reference::equality": None, + } + response = await scoring_impl.score( + input_rows=rows.rows, + scoring_functions=scoring_functions, + ) + assert len(response.results) == len(scoring_functions) + for x in scoring_functions: + assert x in response.results + assert len(response.results[x].score_rows) == len(rows.rows) - # test score using newly registered scoring function - await register_dataset(datasets_impl) - response = await datasets_impl.list_datasets() - assert len(response) == 1 - response = await scoring_impl.score_batch( - dataset_id=response[0].identifier, - scoring_functions=[ - "meta-reference::llm_as_judge_8b_random", - ], - ) - assert "meta-reference::llm_as_judge_8b_random" in response.results + # score batch + response = await scoring_impl.score_batch( + dataset_id="test_dataset", + scoring_functions=scoring_functions, + ) + assert len(response.results) == len(scoring_functions) + for x in scoring_functions: + assert x in response.results + assert len(response.results[x].score_rows) == 5 + @pytest.mark.asyncio + async def test_scoring_score_with_params(self, scoring_stack): + scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = ( + scoring_stack + ) + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert len(response) == 1 -@pytest.mark.asyncio -async def test_scoring_score(scoring_settings, provider_scoring_functions): - scoring_impl = scoring_settings["scoring_impl"] - datasets_impl = scoring_settings["datasets_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - await register_dataset(datasets_impl) + # scoring individual rows + rows = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) + assert len(rows.rows) == 3 - response = await datasets_impl.list_datasets() - assert len(response) == 1 + scoring_functions = { + "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams( + judge_model="Llama3.1-405B-Instruct", + prompt_template="Output a number response in the following format: Score: , where is the number between 0 and 9.", + judge_score_regexes=[r"Score: (\d+)"], + ) + } - # get current provider_type we're testing - scoring_functions = await scoring_functions_impl.list_scoring_functions() - function_ids = [f.identifier for f in scoring_functions] - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type + response = await scoring_impl.score( + input_rows=rows.rows, + scoring_functions=scoring_functions, + ) + assert len(response.results) == len(scoring_functions) + for x in scoring_functions: + assert x in response.results + assert len(response.results[x].score_rows) == len(rows.rows) - response = await scoring_impl.score_batch( - dataset_id=response[0].identifier, - scoring_functions=list(provider_scoring_functions[provider_type]), - ) - - assert len(response.results) == len(provider_scoring_functions[provider_type]) - for x in provider_scoring_functions[provider_type]: - assert x in response.results + # score batch + response = await scoring_impl.score_batch( + dataset_id="test_dataset", + scoring_functions=scoring_functions, + ) + assert len(response.results) == len(scoring_functions) + for x in scoring_functions: + assert x in response.results + assert len(response.results[x].score_rows) == 5