From ea80f623fbbc1940a9fabd1694d89a642bca7376 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 14:19:33 -0800 Subject: [PATCH] add default eval_task_id for routing --- llama_stack/apis/eval/eval.py | 13 +++-- llama_stack/distribution/routers/routers.py | 54 ++++++++++++------- .../inline/meta_reference/eval/eval.py | 15 ++++-- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 9613ffc58..85e6666b4 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -86,13 +86,20 @@ class Eval(Protocol): input_rows: List[Dict[str, Any]], scoring_functions: List[str], eval_task_config: EvalTaskConfig, + eval_task_id: Optional[str] = None, ) -> EvaluateResponse: ... @webmethod(route="/eval/job/status", method="GET") - async def job_status(self, job_id: str) -> Optional[JobStatus]: ... + async def job_status( + self, job_id: str, eval_task_id: Optional[str] = None + ) -> Optional[JobStatus]: ... @webmethod(route="/eval/job/cancel", method="POST") - async def job_cancel(self, job_id: str) -> None: ... + async def job_cancel( + self, job_id: str, eval_task_id: Optional[str] = None + ) -> None: ... @webmethod(route="/eval/job/result", method="GET") - async def job_result(self, job_id: str) -> EvaluateResponse: ... + async def job_result( + self, job_id: str, eval_task_id: Optional[str] = None + ) -> EvaluateResponse: ... 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index ca9da7368..e59cc4ec7 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -280,10 +280,7 @@ class EvalRouter(Eval): task: EvalTaskDef, task_config: AppEvalTaskConfig, ) -> Job: - # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).run_eval( + return await self.routing_table.get_provider_impl(task.identifier).run_eval( task=task, task_config=task_config, ) @@ -293,29 +290,46 @@ class EvalRouter(Eval): self, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - eval_task_config: EvalTaskConfig, # type: ignore + eval_task_config: EvalTaskConfig, + eval_task_id: Optional[str] = None, ) -> EvaluateResponse: # NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task # We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).evaluate_rows( + if eval_task_id is None: + eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER + return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows( input_rows=input_rows, scoring_functions=scoring_functions, eval_task_config=eval_task_config, ) - async def job_status(self, job_id: str) -> Optional[JobStatus]: - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).job_status(job_id) + async def job_status( + self, + job_id: str, + eval_task_id: Optional[str] = None, + ) -> Optional[JobStatus]: + if eval_task_id is None: + eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER + return await self.routing_table.get_provider_impl(eval_task_id).job_status( + job_id + ) - async def job_cancel(self, job_id: str) -> None: - await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER 
- ).job_cancel(job_id) + async def job_cancel( + self, + job_id: str, + eval_task_id: Optional[str] = None, + ) -> None: + if eval_task_id is None: + eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER + await self.routing_table.get_provider_impl(eval_task_id).job_cancel(job_id) - async def job_result(self, job_id: str) -> EvaluateResponse: - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).job_result(job_id) + async def job_result( + self, + job_id: str, + eval_task_id: Optional[str] = None, + ) -> EvaluateResponse: + if eval_task_id is None: + eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER + return await self.routing_table.get_provider_impl(eval_task_id).job_result( + job_id + ) diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py index 2c2cea53c..aaf146c18 100644 --- a/llama_stack/providers/inline/meta_reference/eval/eval.py +++ b/llama_stack/providers/inline/meta_reference/eval/eval.py @@ -25,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from .config import MetaReferenceEvalConfig + +# NOTE: this is the default eval task identifier for app eval +# it is used to make the router work for all app evals +# For app eval using other eval providers, the eval task identifier will be different DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval" @@ -132,6 +136,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): input_rows: List[Dict[str, Any]], scoring_functions: List[str], eval_task_config: EvalTaskConfig, + eval_task_id: Optional[str] = None, ) -> EvaluateResponse: candidate = eval_task_config.eval_candidate if candidate.type == "agent": @@ -204,16 +209,20 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, job_id: str) -> Optional[JobStatus]: + async def job_status( + self, job_id: str, 
eval_task_id: Optional[str] = None + ) -> Optional[JobStatus]: if job_id in self.jobs: return JobStatus.completed return None - async def job_cancel(self, job_id: str) -> None: + async def job_cancel(self, job_id: str, eval_task_id: Optional[str] = None) -> None: raise NotImplementedError("Job cancel is not implemented yet") - async def job_result(self, job_id: str) -> EvaluateResponse: + async def job_result( + self, job_id: str, eval_task_id: Optional[str] = None + ) -> EvaluateResponse: status = await self.job_status(job_id) if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}")