diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py index 0007066aa..cfa00034c 100644 --- a/llama_stack/apis/eval_tasks/eval_tasks.py +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -7,12 +7,14 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkab from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field +from pydantic import Field + +from llama_stack.apis.resource import Resource @json_schema_type -class EvalTaskDef(BaseModel): - identifier: str +class EvalTask(Resource): + type: Literal["eval_task"] = "eval_task" dataset_id: str scoring_functions: List[str] metadata: Dict[str, Any] = Field( @@ -21,23 +23,21 @@ class EvalTaskDef(BaseModel): ) -@json_schema_type -class EvalTaskDefWithProvider(EvalTaskDef): - type: Literal["eval_task"] = "eval_task" - provider_id: str = Field( - description="ID of the provider which serves this dataset", - ) - - @runtime_checkable class EvalTasks(Protocol): @webmethod(route="/eval_tasks/list", method="GET") - async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ... + async def list_eval_tasks(self) -> List[EvalTask]: ... @webmethod(route="/eval_tasks/get", method="GET") - async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ... + async def get_eval_task(self, name: str) -> Optional[EvalTask]: ... @webmethod(route="/eval_tasks/register", method="POST") async def register_eval_task( - self, eval_task_def: EvalTaskDefWithProvider + self, + eval_task_id: str, + dataset_id: str, + scoring_functions: List[str], + provider_id: Optional[str] = None, + provider_eval_task_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> None: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ad246789e..2ad9bc282 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -105,8 +105,6 @@ class CommonRoutingTableImpl(RoutingTable): elif api == Api.eval: p.eval_task_store = self - eval_tasks = await p.list_eval_tasks() - await add_objects(eval_tasks, pid, EvalTaskDefWithProvider) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -357,11 +355,38 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): - async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]: + async def list_eval_tasks(self) -> List[EvalTask]: return await self.get_all_with_type("eval_task") - async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: + async def get_eval_task(self, name: str) -> Optional[EvalTask]: return await self.get_object_by_identifier(name) - async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None: - await self.register_object(eval_task_def) + async def register_eval_task( + self, + eval_task_id: str, + dataset_id: str, + scoring_functions: List[str], + metadata: Optional[Dict[str, Any]] = None, + provider_id: Optional[str] = None, + provider_eval_task_id: Optional[str] = None, + ) -> None: + if metadata is None: + metadata = {} + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." 
+ ) + if provider_eval_task_id is None: + provider_eval_task_id = eval_task_id + eval_task = EvalTask( + identifier=eval_task_id, + dataset_id=dataset_id, + scoring_functions=scoring_functions, + metadata=metadata, + provider_id=provider_id, + provider_resource_id=provider_eval_task_id, + ) + await self.register_object(eval_task) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index aeb0be742..84771190e 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -12,8 +12,8 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from llama_stack.apis.datasets import Dataset -from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.memory_banks.memory_banks import MemoryBank +from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFnDef from llama_stack.apis.shields import Shield @@ -67,9 +67,9 @@ class ScoringFunctionsProtocolPrivate(Protocol): class EvalTasksProtocolPrivate(Protocol): - async def list_eval_tasks(self) -> List[EvalTaskDef]: ... + async def list_eval_tasks(self) -> List[EvalTask]: ... - async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ... + async def register_eval_task(self, eval_task: EvalTask) -> None: ... @json_schema_type diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index df642f33b..23d3cf6dc 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -11,7 +11,7 @@ from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatu from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.eval_tasks import EvalTaskDef +from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.inference import Inference from llama_stack.apis.scoring import Scoring from llama_stack.providers.datatypes import EvalTasksProtocolPrivate @@ -53,10 +53,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): async def shutdown(self) -> None: ... 
- async def register_eval_task(self, task_def: EvalTaskDef) -> None: + async def register_eval_task(self, task_def: EvalTask) -> None: self.eval_tasks[task_def.identifier] = task_def - async def list_eval_tasks(self) -> List[EvalTaskDef]: + async def list_eval_tasks(self) -> List[EvalTask]: return list(self.eval_tasks.values()) async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None: diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 9f14c61ef..0728ce5d9 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -16,7 +16,6 @@ from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider from llama_stack.apis.eval.eval import ( AppEvalTaskConfig, BenchmarkEvalTaskConfig, - EvalTaskDefWithProvider, ModelCandidate, ) from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams @@ -70,13 +69,11 @@ class Testeval: "meta-reference::equality", ] task_id = "meta-reference::app_eval" - task_def = EvalTaskDefWithProvider( - identifier=task_id, + await eval_tasks_impl.register_eval_task( + eval_task_id=task_id, dataset_id="test_dataset_for_eval", scoring_functions=scoring_functions, - provider_id="meta-reference", ) - await eval_tasks_impl.register_eval_task(task_def) response = await eval_impl.evaluate_rows( task_id=task_id, input_rows=rows.rows, @@ -125,13 +122,11 @@ class Testeval: ] task_id = "meta-reference::app_eval-2" - task_def = EvalTaskDefWithProvider( - identifier=task_id, + await eval_tasks_impl.register_eval_task( + eval_task_id=task_id, dataset_id="test_dataset_for_eval", scoring_functions=scoring_functions, - provider_id="meta-reference", ) - await eval_tasks_impl.register_eval_task(task_def) response = await eval_impl.run_eval( task_id=task_id, task_config=AppEvalTaskConfig( @@ -189,15 +184,12 @@ class Testeval: await datasets_impl.register_dataset(mmlu) # register eval task - meta_reference_mmlu = EvalTaskDefWithProvider( - identifier="meta-reference-mmlu", + await eval_tasks_impl.register_eval_task( + eval_task_id="meta-reference-mmlu", dataset_id="mmlu", scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"], - provider_id="", ) - await eval_tasks_impl.register_eval_task(meta_reference_mmlu) - # list benchmarks response = await eval_tasks_impl.list_eval_tasks() assert len(response) > 0
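For reference, a minimal sketch of how callers use the new flat-argument registration API introduced by this diff. The `eval_tasks_impl` name mirrors the test fixture above and is assumed to be an `EvalTasks` implementation (e.g. the `EvalTasksRoutingTable`); the sketch also assumes the `mmlu` dataset and the listed scoring function are already registered, and that exactly one eval provider is configured so `provider_id` can be omitted.

```python
# Register an eval task by id instead of constructing an EvalTaskDefWithProvider object.
await eval_tasks_impl.register_eval_task(
    eval_task_id="meta-reference-mmlu",
    dataset_id="mmlu",
    scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
    # provider_id may be omitted when exactly one eval provider is configured;
    # provider_eval_task_id defaults to eval_task_id when not given.
)

# The routing table now stores an EvalTask resource, retrievable by identifier.
tasks = await eval_tasks_impl.list_eval_tasks()
task = await eval_tasks_impl.get_eval_task("meta-reference-mmlu")
assert task is not None and task.type == "eval_task"
```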