From f05db9a25c9ba27206ab6d35d3a8db40634875d7 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 14:30:46 -0800
Subject: [PATCH] add eval_id for jobs

---
 llama_stack/apis/eval/eval.py                 | 12 +++------
 llama_stack/distribution/routers/routers.py   | 24 ++++++++----------
 .../inline/meta_reference/eval/eval.py        | 25 +++++++------------
 llama_stack/providers/tests/eval/test_eval.py | 10 +++++---
 4 files changed, 30 insertions(+), 41 deletions(-)

diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 85e6666b4..6aa4cae34 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -85,21 +85,17 @@ class Eval(Protocol):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        eval_task_config: EvalTaskConfig,
+        task_config: EvalTaskConfig,
         eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/eval/job/status", method="GET")
     async def job_status(
-        self, job_id: str, eval_task_id: Optional[str] = None
+        self, job_id: str, eval_task_id: str
     ) -> Optional[JobStatus]: ...
 
     @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(
-        self, job_id: str, eval_task_id: Optional[str] = None
-    ) -> None: ...
+    async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ...
 
     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(
-        self, job_id: str, eval_task_id: Optional[str] = None
-    ) -> EvaluateResponse: ...
+    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index e59cc4ec7..06d50bd65 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -290,7 +290,7 @@ class EvalRouter(Eval):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        eval_task_config: EvalTaskConfig,
+        task_config: EvalTaskConfig,
         eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
         # NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task
@@ -300,36 +300,32 @@
         return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows(
             input_rows=input_rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            task_config=task_config,
         )
 
     async def job_status(
         self,
         job_id: str,
-        eval_task_id: Optional[str] = None,
+        eval_task_id: str,
     ) -> Optional[JobStatus]:
-        if eval_task_id is None:
-            eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
         return await self.routing_table.get_provider_impl(eval_task_id).job_status(
-            job_id
+            job_id, eval_task_id
         )
 
     async def job_cancel(
         self,
         job_id: str,
-        eval_task_id: Optional[str] = None,
+        eval_task_id: str,
     ) -> None:
-        if eval_task_id is None:
-            eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
-        await self.routing_table.get_provider_impl(eval_task_id).job_cancel(job_id)
+        await self.routing_table.get_provider_impl(eval_task_id).job_cancel(
+            job_id, eval_task_id
+        )
 
     async def job_result(
         self,
         job_id: str,
-        eval_task_id: Optional[str] = None,
+        eval_task_id: str,
     ) -> EvaluateResponse:
-        if eval_task_id is None:
-            eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
         return await self.routing_table.get_provider_impl(eval_task_id).job_result(
-            job_id
+            job_id, eval_task_id
         )
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index aaf146c18..a9a1978e9 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -122,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=task_config,
+            task_config=task_config,
         )
 
         # TODO: currently needs to wait for generation before returning
@@ -135,10 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        eval_task_config: EvalTaskConfig,
+        task_config: EvalTaskConfig,
         eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
-        candidate = eval_task_config.eval_candidate
+        candidate = task_config.eval_candidate
         if candidate.type == "agent":
             raise NotImplementedError(
                 "Evaluation with generation has not been implemented for agents"
@@ -190,12 +190,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             for input_r, generated_r in zip(input_rows, generations)
         ]
 
-        if (
-            eval_task_config.type == "app"
-            and eval_task_config.scoring_params is not None
-        ):
+        if task_config.type == "app" and task_config.scoring_params is not None:
             scoring_functions_dict = {
-                scoring_fn_id: eval_task_config.scoring_params.get(scoring_fn_id, None)
+                scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
                 for scoring_fn_id in scoring_functions
             }
         else:
@@ -209,21 +206,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(
-        self, job_id: str, eval_task_id: Optional[str] = None
-    ) -> Optional[JobStatus]:
+    async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
         if job_id in self.jobs:
             return JobStatus.completed
         return None
 
-    async def job_cancel(self, job_id: str, eval_task_id: Optional[str] = None) -> None:
+    async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(
-        self, job_id: str, eval_task_id: Optional[str] = None
-    ) -> EvaluateResponse:
-        status = await self.job_status(job_id)
+    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
+        status = await self.job_status(job_id, eval_task_id)
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 0243ca64c..d97f74ec4 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -52,7 +52,7 @@ class Testeval:
         response = await eval_impl.evaluate_rows(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
@@ -90,9 +90,13 @@
                 ),
             ),
         )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(response.job_id)
+        job_status = await eval_impl.job_status(
+            response.job_id, "meta-reference::app_eval"
+        )
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(response.job_id)
+        eval_response = await eval_impl.job_result(
+            response.job_id, "meta-reference::app_eval"
+        )
         assert eval_response is not None
         assert len(eval_response.generations) == 5
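Reviewer note (not part of the patch): a minimal usage sketch of the updated Eval API. After this change, evaluate_rows takes `task_config` (renamed from `eval_task_config`), and job_status / job_cancel / job_result require an explicit `eval_task_id`, which EvalRouter uses to resolve the provider implementation instead of falling back to DEFAULT_EVAL_TASK_IDENTIFIER. The `eval_impl`, `input_rows`, `scoring_functions`, and `task_config` objects are assumed to be set up as in test_eval.py; the helper name and the task identifier below are illustrative only.

from typing import Any, Dict, List

# Illustrative eval task identifier; mirrors the value used in test_eval.py.
EVAL_TASK_ID = "meta-reference::app_eval"


async def evaluate_and_fetch(
    eval_impl,
    input_rows: List[Dict[str, Any]],
    scoring_functions: List[str],
    task_config,
    job_id: str,
):
    # Direct path: the evaluate_rows keyword is now `task_config`
    # (renamed from `eval_task_config`).
    eval_response = await eval_impl.evaluate_rows(
        input_rows=input_rows,
        scoring_functions=scoring_functions,
        task_config=task_config,
    )

    # Job path: job_status/job_result now require eval_task_id, which the
    # EvalRouter uses to pick the provider implementation (the old
    # DEFAULT_EVAL_TASK_IDENTIFIER fallback in the router is removed).
    status = await eval_impl.job_status(job_id, EVAL_TASK_ID)
    if status is not None and status.value == "completed":
        return eval_response, await eval_impl.job_result(job_id, EVAL_TASK_ID)
    return eval_response, None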