From 8fbbea8c432533bf895ad58ec20fcaa81aeebc3a Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sun, 17 Nov 2024 20:18:26 -0800
Subject: [PATCH] refactor

---
 .../inline/eval/meta_reference/__init__.py |  1 +
 .../inline/eval/meta_reference/eval.py     | 40 ++++++++++++++-----
 llama_stack/providers/registry/eval.py     |  1 +
 3 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py
index fb285c668..56c115322 100644
--- a/llama_stack/providers/inline/eval/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/eval/meta_reference/__init__.py
@@ -22,6 +22,7 @@ async def get_provider_impl(
         deps[Api.datasets],
         deps[Api.scoring],
         deps[Api.inference],
+        deps[Api.agents],
     )
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index aa22ad31b..8b13a1f28 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -9,6 +9,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
@@ -39,12 +40,14 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         datasets_api: Datasets,
         scoring_api: Scoring,
         inference_api: Inference,
+        agent_api: Agents,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
         self.scoring_api = scoring_api
         self.inference_api = inference_api
+        self.agent_api = agent_api

         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
@@ -126,18 +129,15 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.jobs[job_id] = res
         return Job(job_id=job_id)

-    async def evaluate_rows(
-        self,
-        task_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        task_config: EvalTaskConfig,
-    ) -> EvaluateResponse:
+    async def _run_agent_generation(
+        self, task_config: EvalTaskConfig
+    ) -> List[Dict[str, Any]]:
+        pass
+
+    async def _run_model_generation(
+        self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
+    ) -> List[Dict[str, Any]]:
         candidate = task_config.eval_candidate
-        if candidate.type == "agent":
-            raise NotImplementedError(
-                "Evaluation with generation has not been implemented for agents"
-            )
         assert (
             candidate.sampling_params.max_tokens is not None
         ), "SamplingParams.max_tokens must be provided"
@@ -179,6 +179,24 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             else:
                 raise ValueError("Invalid input row")

+        return generations
+
+    async def evaluate_rows(
+        self,
+        task_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        task_config: EvalTaskConfig,
+    ) -> EvaluateResponse:
+        candidate = task_config.eval_candidate
+        if candidate.type == "agent":
+            raise NotImplementedError(
+                "Evaluation with generation has not been implemented for agents"
+            )
+
+        if candidate.type == "model":
+            generations = await self._run_model_generation(input_rows, task_config)
+
         # scoring with generated_answer
         score_input_rows = [
             input_r | generated_r
diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py
index 3fa5c75e0..718c7eae5 100644
--- a/llama_stack/providers/registry/eval.py
+++ b/llama_stack/providers/registry/eval.py
@@ -22,6 +22,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.datasets,
             Api.scoring,
             Api.inference,
+            Api.agents,
         ],
     ),
 ]
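
Note (not part of the patch): the structural change here is to pull generation out of evaluate_rows into per-candidate helpers and dispatch on eval_candidate.type, with _run_agent_generation left as a stub for follow-up work against the newly injected Agents API. Below is a minimal, self-contained sketch of that dispatch-and-merge pattern, assuming simplified stand-in names and plain dict rows rather than the real llama_stack types; it is illustrative only, not this provider's implementation.

import asyncio
from typing import Any, Dict, List


class ToyEvalImpl:
    """Toy model of the dispatch this patch introduces (names are illustrative)."""

    async def _run_model_generation(
        self, rows: List[Dict[str, Any]], candidate_type: str
    ) -> List[Dict[str, Any]]:
        # The real helper calls the inference API per row; here we stub the answers.
        return [{"generated_answer": f"stub answer for row {i}"} for i, _ in enumerate(rows)]

    async def _run_agent_generation(
        self, rows: List[Dict[str, Any]], candidate_type: str
    ) -> List[Dict[str, Any]]:
        # Left as a stub in the patch; a follow-up would drive the injected Agents API.
        raise NotImplementedError("agent candidates are not supported yet")

    async def evaluate_rows(
        self, rows: List[Dict[str, Any]], candidate_type: str
    ) -> List[Dict[str, Any]]:
        # Dispatch on the candidate type, mirroring the refactored evaluate_rows.
        if candidate_type == "model":
            generations = await self._run_model_generation(rows, candidate_type)
        elif candidate_type == "agent":
            generations = await self._run_agent_generation(rows, candidate_type)
        else:
            raise ValueError(f"unknown candidate type: {candidate_type}")
        # Merge each generation back into its input row before scoring,
        # mirroring the `input_r | generated_r` step in the real code.
        return [row | gen for row, gen in zip(rows, generations)]


if __name__ == "__main__":
    rows = [{"input_query": "2 + 2 = ?", "expected_answer": "4"}]
    print(asyncio.run(ToyEvalImpl().evaluate_rows(rows, "model")))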