diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
index 0007066aa..870673e58 100644
--- a/llama_stack/apis/eval_tasks/eval_tasks.py
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -7,12 +7,14 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkab
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+
+from llama_stack.apis.resource import Resource
 
 
 @json_schema_type
-class EvalTaskDef(BaseModel):
-    identifier: str
+class EvalTask(Resource):
+    type: Literal["eval_task"] = "eval_task"
     dataset_id: str
     scoring_functions: List[str]
     metadata: Dict[str, Any] = Field(
@@ -21,23 +23,21 @@ class EvalTaskDef(BaseModel):
     )
 
 
-@json_schema_type
-class EvalTaskDefWithProvider(EvalTaskDef):
-    type: Literal["eval_task"] = "eval_task"
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )
-
-
 @runtime_checkable
 class EvalTasks(Protocol):
     @webmethod(route="/eval_tasks/list", method="GET")
-    async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...
+    async def list_eval_tasks(self) -> List[EvalTask]: ...
 
     @webmethod(route="/eval_tasks/get", method="GET")
-    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
 
     @webmethod(route="/eval_tasks/register", method="POST")
     async def register_eval_task(
-        self, eval_task_def: EvalTaskDefWithProvider
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
     ) -> None: ...
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index ad246789e..b0091f5a0 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -105,8 +105,6 @@ class CommonRoutingTableImpl(RoutingTable):
             elif api == Api.eval:
                 p.eval_task_store = self
-                eval_tasks = await p.list_eval_tasks()
-                await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -357,11 +355,38 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
 
 
 class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
-    async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
+    async def list_eval_tasks(self) -> List[EvalTask]:
         return await self.get_all_with_type("eval_task")
 
-    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]:
         return await self.get_object_by_identifier(name)
 
-    async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
-        await self.register_object(eval_task_def)
+    async def register_eval_task(
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        metadata: Optional[Dict[str, Any]] = None,
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+    ) -> None:
+        if metadata is None:
+            metadata = {}
+        if provider_id is None:
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. Please specify a provider_id."
+                )
+        if provider_eval_task_id is None:
+            provider_eval_task_id = eval_task_id
+        eval_task = EvalTask(
+            identifier=eval_task_id,
+            dataset_id=dataset_id,
+            scoring_functions=scoring_functions,
+            metadata=metadata,
+            provider_id=provider_id,
+            provider_resource_id=provider_eval_task_id,
+        )
+        await self.register_object(eval_task)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index aeb0be742..f065d4f33 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasets import Dataset
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFnDef
@@ -67,9 +67,7 @@ class ScoringFunctionsProtocolPrivate(Protocol):
 
 
 class EvalTasksProtocolPrivate(Protocol):
-    async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
-
-    async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
+    async def register_eval_task(self, eval_task: EvalTask) -> None: ...
 
 
 @json_schema_type
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index df642f33b..ba2fc7c95 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -11,7 +11,7 @@ from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatu
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
@@ -53,15 +53,12 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
     async def shutdown(self) -> None: ...
 
-    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+    async def register_eval_task(self, task_def: EvalTask) -> None:
         self.eval_tasks[task_def.identifier] = task_def
 
-    async def list_eval_tasks(self) -> List[EvalTaskDef]:
-        return list(self.eval_tasks.values())
-
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
             raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
         expected_schemas = [
@@ -77,7 +74,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             },
         ]
 
-        if dataset_def.dataset_schema not in expected_schemas:
+        if dataset_def.schema not in expected_schemas:
             raise ValueError(
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 9f14c61ef..92c4d0331 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -11,12 +11,9 @@ from llama_models.llama3.api import SamplingParams, URL
 
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 
-from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
-
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
     BenchmarkEvalTaskConfig,
-    EvalTaskDefWithProvider,
     ModelCandidate,
 )
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
@@ -70,13 +67,11 @@ class Testeval:
             "meta-reference::equality",
         ]
         task_id = "meta-reference::app_eval"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.evaluate_rows(
             task_id=task_id,
             input_rows=rows.rows,
@@ -125,13 +120,11 @@ class Testeval:
         ]
 
         task_id = "meta-reference::app_eval-2"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
             task_id=task_id,
             task_config=AppEvalTaskConfig(
@@ -169,35 +162,29 @@ class Testeval:
             pytest.skip(
                 "Only huggingface provider supports pre-registered remote datasets"
             )
-        # register dataset
-        mmlu = DatasetDefWithProvider(
-            identifier="mmlu",
-            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
-            dataset_schema={
+
+        await datasets_impl.register_dataset(
+            dataset_id="mmlu",
+            schema={
                 "input_query": StringType(),
                 "expected_answer": StringType(),
                 "chat_completion_input": ChatCompletionInputType(),
             },
+            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
             metadata={
                 "path": "llamastack/evals",
                 "name": "evals__mmlu__details",
                 "split": "train",
             },
-            provider_id="",
         )
-        await datasets_impl.register_dataset(mmlu)
 
-        # register eval task
-        meta_reference_mmlu = EvalTaskDefWithProvider(
-            identifier="meta-reference-mmlu",
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id="meta-reference-mmlu",
             dataset_id="mmlu",
             scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
-            provider_id="",
         )
-        await eval_tasks_impl.register_eval_task(meta_reference_mmlu)
-
-        # list benchmarks
         response = await eval_tasks_impl.list_eval_tasks()
         assert len(response) > 0
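A minimal usage sketch of the flattened registration API introduced above, assuming an already-resolved `eval_tasks_impl` handle to the EvalTasks routing table (as in the test fixtures); the task id and scoring function shown are illustrative only:

# Hypothetical caller of the new EvalTasks registration API (sketch, not part of the diff).
async def register_and_fetch(eval_tasks_impl) -> None:
    await eval_tasks_impl.register_eval_task(
        eval_task_id="my-app-eval",  # stored as EvalTask.identifier
        dataset_id="test_dataset_for_eval",
        scoring_functions=["meta-reference::equality"],
        # provider_id may be omitted when exactly one eval provider is configured;
        # provider_eval_task_id defaults to eval_task_id and metadata defaults to {}.
    )
    task = await eval_tasks_impl.get_eval_task(name="my-app-eval")
    assert task is not None
    assert task.dataset_id == "test_dataset_for_eval"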