From 4a64f98c8226955f7fa471aff3bd1756154da8e4 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 5 Nov 2024 16:54:31 -0800
Subject: [PATCH] separate benchmark / app eval

---
 llama_stack/apis/eval/eval.py             | 26 +++++++++++-----------------
 llama_stack/apis/eval_tasks/eval_tasks.py |  2 +-
 2 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index df56f5325..9cb00145e 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -38,47 +38,39 @@ EvalCandidate = Annotated[
 
 @json_schema_type
 class BenchmarkEvalTaskConfig(BaseModel):
-    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate  # type: ignore
 
 
 @json_schema_type
 class AppEvalTaskConfig(BaseModel):
-    type: Literal["app"] = "app"
     eval_candidate: EvalCandidate  # type: ignore
     scoring_functions_params: Dict[str, ScoringFnParams] = Field(  # type: ignore
         description="Map between scoring function id and parameters",
         default_factory=dict,
     )
-    # we could optinally add any GenEval specific dataset config here
-
-
-EvalTaskConfig = Annotated[
-    Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
-]
+    # we could optionally add any specific dataset config here
 
 
 @json_schema_type
 class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]
-
     # each key in the dict is a scoring function name
     scores: Dict[str, ScoringResult]
 
 
 class Eval(Protocol):
-    @webmethod(route="/eval/evaluate_batch", method="POST")
-    async def evaluate_task(
+    @webmethod(route="/eval/run_benchmark", method="POST")
+    async def run_benchmark(
         self,
-        eval_task_id: str,
-        eval_task_config: EvalTaskConfig,  # type: ignore
+        benchmark_id: str,
+        eval_task_config: BenchmarkEvalTaskConfig,  # type: ignore
     ) -> Job: ...
 
-    @webmethod(route="/eval/evaluate_batch", method="POST")
-    async def evaluate_batch(
+    @webmethod(route="/eval/run_app_eval", method="POST")
+    async def run_app_eval(
         self,
-        eval_task_def: Union[str, EvalTaskDef],  # type: ignore
-        eval_task_config: EvalTaskConfig,  # type: ignore
+        eval_task_def: EvalTaskDef,  # type: ignore
+        eval_task_config: AppEvalTaskConfig,  # type: ignore
     ) -> Job: ...
 
     @webmethod(route="/eval/evaluate", method="POST")
diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
index 0499d5cf9..539b48521 100644
--- a/llama_stack/apis/eval_tasks/eval_tasks.py
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -17,7 +17,7 @@ class EvalTaskDef(BaseModel):
     scoring_functions: List[str]
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
-        description="Metadata for this evaluation task (e.g. from GECO)",
+        description="Metadata for this evaluation task",
     )
 
 
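
Usage sketch (editorial note, not part of the patch): after this change,
benchmark evals and app evals go through separate endpoints with their own
config types instead of a single discriminated EvalTaskConfig. The snippet
below is a minimal illustration of the two new call paths. It assumes an
`eval_impl` object implementing the Eval protocol and an `eval_candidate`
constructed elsewhere; the EvalTaskDef fields `identifier` and `dataset_id`
and the example ids are assumptions, not defined in this diff.

    import asyncio

    async def main() -> None:
        # Benchmark path: the task is referenced by a pre-registered
        # benchmark id; the config only carries the candidate to evaluate.
        benchmark_job = await eval_impl.run_benchmark(
            benchmark_id="mmlu",  # hypothetical benchmark id
            eval_task_config=BenchmarkEvalTaskConfig(
                eval_candidate=eval_candidate,
            ),
        )

        # App eval path: the caller supplies the full task definition and
        # may pass per-call scoring function parameters.
        task_def = EvalTaskDef(
            identifier="my-app-eval",      # assumed field
            dataset_id="my-eval-dataset",  # assumed field
            scoring_functions=["llm-as-judge::base"],  # hypothetical id
        )
        app_job = await eval_impl.run_app_eval(
            eval_task_def=task_def,
            eval_task_config=AppEvalTaskConfig(
                eval_candidate=eval_candidate,
                scoring_functions_params={},  # per-function overrides, if any
            ),
        )

    asyncio.run(main())

Splitting the union into two endpoints makes each route's accepted config
explicit in the schema, at the cost of a second webmethod to maintain.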