diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index aa8e4d5e9..9613ffc58 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -38,14 +38,16 @@ EvalCandidate = Annotated[
 
 @json_schema_type
 class BenchmarkEvalTaskConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
 
 
 @json_schema_type
 class AppEvalTaskConfig(BaseModel):
+    type: Literal["app"] = "app"
     eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
-        description="Map between scoring function id and parameters",
+        description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
     # we could optinally add any specific dataset config here
@@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel):
 
 
 class Eval(Protocol):
-    @webmethod(route="/eval/run_benchmark_eval", method="POST")
-    async def run_benchmark_eval(
+    @webmethod(route="/eval/run_benchmark", method="POST")
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job: ...
 
     @webmethod(route="/eval/run_eval", method="POST")
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job: ...
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index a518bd806..a68582057 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -48,8 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
@@ -57,6 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 341f84c36..b3906b225 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -76,7 +76,7 @@ class ScoringFnDef(BaseModel):
         description="The return type of the deterministic function",
     )
     params: Optional[ScoringFnParams] = Field(  # type: ignore
-        description="The parameters for the scoring function for benchmark eval, we could override this for app eval",
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
     # We can optionally add information here to support packaging of code, etc.
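Usage note on the API change above: the old pair of arguments (scoring_functions as a List[str] plus a separate scoring_params dict) collapses into a single scoring_functions mapping from scoring function id to optional per-call parameters. A minimal caller-side sketch of the new shape, not part of this patch; scoring_impl and rows are stand-ins for any provider implementing the updated Scoring protocol and its input rows, and the ids mirror the tests later in this diff:

    from typing import Any, Dict, List


    async def score_rows(scoring_impl, rows: List[Dict[str, Any]]):
        # Keys are scoring function ids; a None value means "use the function's
        # registered/default params". To override per call, pass a ScoringFnParams
        # instance instead (see the LLMAsJudgeScoringFnParams test further down).
        scoring_functions = {
            "meta-reference::equality": None,
            "meta-reference::llm_as_judge_8b_correctness": None,
        }
        return await scoring_impl.score(
            input_rows=rows,
            scoring_functions=scoring_functions,
        )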
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 18c78b06c..ca9da7368 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score_batch(
                 dataset_id=dataset_id,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
 
@@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score(
                 input_rows=input_rows,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
 
@@ -272,24 +268,24 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
        pass
 
-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         pass
 
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
         # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
         return await self.routing_table.get_provider_impl(
             DEFAULT_EVAL_TASK_IDENTIFIER
         ).run_eval(
-            eval_task_def=eval_task_def,
-            eval_task_config=eval_task_config,
+            task=task,
+            task_config=task_config,
         )
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index 38c5869c2..2c2cea53c 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -7,11 +7,17 @@ from enum import Enum
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkEvalTaskConfig
+from .....apis.eval.eval import (
+    AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
+    Eval,
+    EvalTaskConfig,
+    EvaluateResponse,
+    JobStatus,
+)
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
@@ -88,21 +94,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
 
-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         raise NotImplementedError("Benchmark eval is not implemented yet")
 
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
-        dataset_id = eval_task_def.dataset_id
-        candidate = eval_task_config.eval_candidate
-        scoring_functions = eval_task_def.scoring_functions
+        dataset_id = task.dataset_id
+        candidate = task_config.eval_candidate
+        scoring_functions = task.scoring_functions
 
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
@@ -112,7 +118,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            eval_task_config=task_config,
         )
 
         # TODO: currently needs to wait for generation before returning
@@ -179,8 +185,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             for input_r, generated_r in zip(input_rows, generations)
         ]
 
+        if (
+            eval_task_config.type == "app"
+            and eval_task_config.scoring_params is not None
+        ):
+            scoring_functions_dict = {
+                scoring_fn_id: eval_task_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in scoring_functions
+            }
+        else:
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
+
         score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions
+            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
index 56cf13503..d6eb3ae96 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
@@ -96,8 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -108,7 +107,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         res = await self.score(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -122,17 +120,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = None
-            if scoring_params is not None:
-                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 6fe6c07fb..0243ca64c 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -76,13 +76,13 @@ class Testeval:
         ]
 
         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 7a36a96f9..3c1b6554f 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: , where is the number between 0 and 9.",
@@ -91,13 +91,9 @@ class TestScoring:
             )
         }
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
        )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions: