diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
index 539b48521..0007066aa 100644
--- a/llama_stack/apis/eval_tasks/eval_tasks.py
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -35,9 +35,9 @@ class EvalTasks(Protocol):
     async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...

     @webmethod(route="/eval_tasks/get", method="GET")
-    async def get_eval_tasks(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
+    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...

     @webmethod(route="/eval_tasks/register", method="POST")
-    async def register_eval_tasks(
-        self, function_def: EvalTaskDefWithProvider
+    async def register_eval_task(
+        self, eval_task_def: EvalTaskDefWithProvider
     ) -> None: ...
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 2149162a6..3fc3b2d5d 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -43,6 +43,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             routing_table_api=Api.scoring_functions,
             router_api=Api.scoring,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.eval_tasks,
+            router_api=Api.eval,
+        ),
     ]


diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 96b4b81e6..7a8d1dfee 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -15,6 +15,7 @@ from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
+from llama_stack.apis.eval_tasks import EvalTasks
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -46,6 +47,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.scoring: Scoring,
         Api.scoring_functions: ScoringFunctions,
         Api.eval: Eval,
+        Api.eval_tasks: EvalTasks,
     }


@@ -56,6 +58,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
         Api.safety: (ShieldsProtocolPrivate, Shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets),
         Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
+        Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
     }


diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index b3ebd1368..57e81ac30 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -12,6 +12,7 @@ from llama_stack.distribution.store import DistributionRegistry

 from .routing_tables import (
     DatasetsRoutingTable,
+    EvalTasksRoutingTable,
     MemoryBanksRoutingTable,
     ModelsRoutingTable,
     ScoringFunctionsRoutingTable,
@@ -31,6 +32,7 @@ async def get_routing_table_impl(
         "shields": ShieldsRoutingTable,
         "datasets": DatasetsRoutingTable,
         "scoring_functions": ScoringFunctionsRoutingTable,
+        "eval_tasks": EvalTasksRoutingTable,
     }

     if api.value not in api_to_tables:
@@ -44,6 +46,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
     from .routers import (
         DatasetIORouter,
+        EvalRouter,
         InferenceRouter,
         MemoryRouter,
         SafetyRouter,
@@ -56,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
         "safety": SafetyRouter,
         "datasetio": DatasetIORouter,
         "scoring": ScoringRouter,
+        "eval": EvalRouter,
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 5b8274245..f77e03928 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -14,6 +14,7 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.eval import *  # noqa: F403


 class MemoryRouter(Memory):
@@ -252,3 +253,49 @@ class ScoringRouter(Scoring):
             res.update(score_response.results)

         return ScoreResponse(results=res)
+
+
+class EvalRouter(Eval):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_benchmark_eval(
+        self,
+        benchmark_id: str,
+        eval_task_config: BenchmarkEvalTaskConfig,
+    ) -> Job:
+        pass
+
+    async def run_eval(
+        self,
+        eval_task_def: EvalTaskDef,
+        eval_task_config: EvalTaskConfig,
+    ) -> Job:
+        pass
+
+    @webmethod(route="/eval/evaluate_rows", method="POST")
+    async def evaluate_rows(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        eval_task_config: EvalTaskConfig,  # type: ignore
+    ) -> EvaluateResponse:
+        pass
+
+    async def job_status(self, job_id: str) -> Optional[JobStatus]:
+        pass
+
+    async def job_cancel(self, job_id: str) -> None:
+        pass
+
+    async def job_result(self, job_id: str) -> EvaluateResponse:
+        pass
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 4c5bdf654..a676b5fef 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -12,6 +12,8 @@ from llama_stack.apis.models import *  # noqa: F403
 from llama_stack.apis.shields import *  # noqa: F403
 from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.eval_tasks import *  # noqa: F403
+
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.datatypes import *  # noqa: F403

@@ -40,6 +42,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
         await p.register_dataset(obj)
     elif api == Api.scoring:
         await p.register_scoring_function(obj)
+    elif api == Api.eval:
+        await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")

@@ -103,6 +107,11 @@ class CommonRoutingTableImpl(RoutingTable):
                 scoring_functions = await p.list_scoring_functions()
                 await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)

+            elif api == Api.eval:
+                p.eval_task_store = self
+                eval_tasks = await p.list_eval_tasks()
+                await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)
+
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
@@ -121,6 +130,8 @@ class CommonRoutingTableImpl(RoutingTable):
             return ("DatasetIO", "dataset")
         elif isinstance(self, ScoringFunctionsRoutingTable):
             return ("Scoring", "scoring_function")
+        elif isinstance(self, EvalTasksRoutingTable):
+            return ("Eval", "eval_task")
         else:
             raise ValueError("Unknown routing table type")

@@ -246,7 +257,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         await self.register_object(dataset_def)


-class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
+class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
         return await self.get_all_with_type("scoring_fn")

@@ -259,3 +270,14 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
         self, function_def: ScoringFnDefWithProvider
     ) -> None:
         await self.register_object(function_def)
+
+
+class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
+    async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]:
+        return await self.get_all_with_type("eval_task")
+
+    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
+        return await self.get_object_by_identifier(name)
+
+    async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
+        await self.register_object(eval_task_def)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 69255fc5f..50633a355 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

 from llama_stack.apis.datasets import DatasetDef
+from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.memory_banks import MemoryBankDef
 from llama_stack.apis.models import ModelDef
 from llama_stack.apis.scoring_functions import ScoringFnDef
@@ -34,6 +35,7 @@ class Api(Enum):
     memory_banks = "memory_banks"
     datasets = "datasets"
     scoring_functions = "scoring_functions"
+    eval_tasks = "eval_tasks"

     # built-in API
     inspect = "inspect"
@@ -69,6 +71,12 @@ class ScoringFunctionsProtocolPrivate(Protocol):
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...


+class EvalTasksProtocolPrivate(Protocol):
+    async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
+
+    async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
+
+
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index 3aec6170f..28420ee35 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -6,13 +6,16 @@
 from enum import Enum

 from llama_models.llama3.api.datatypes import *  # noqa: F403

+from .....apis.common.job_types import Job
+from .....apis.eval.eval import BenchmarkEvalTaskConfig
 from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
+from llama_stack.apis.eval import *  # noqa: F403
+from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
+from llama_stack.providers.datatypes import EvalTasksProtocolPrivate

 from .config import MetaReferenceEvalConfig
@@ -25,7 +28,7 @@ class ColumnName(Enum):
     generated_answer = "generated_answer"


-class MetaReferenceEvalImpl(Eval):
+class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
     def __init__(
         self,
         config: MetaReferenceEvalConfig,
@@ -47,6 +50,9 @@ class MetaReferenceEvalImpl(Eval):

     async def shutdown(self) -> None: ...
+    async def list_eval_tasks(self) -> List[EvalTaskDef]:
+        return []
+
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
         if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -70,12 +76,22 @@ class MetaReferenceEvalImpl(Eval):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )

-    async def evaluate_batch(
+    async def run_benchmark_eval(
         self,
-        dataset_id: str,
-        candidate: EvalCandidate,
-        scoring_functions: List[str],
+        benchmark_id: str,
+        eval_task_config: BenchmarkEvalTaskConfig,
     ) -> Job:
+        raise NotImplementedError("Benchmark eval is not implemented yet")
+
+    async def run_eval(
+        self,
+        eval_task_def: EvalTaskDef,
+        eval_task_config: EvalTaskConfig,
+    ) -> Job:
+        dataset_id = eval_task_def.dataset_id
+        candidate = eval_task_config.eval_candidate
+        scoring_functions = eval_task_def.scoring_functions
+
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
@@ -93,12 +109,13 @@ class MetaReferenceEvalImpl(Eval):
         self.jobs[job_id] = res
         return Job(job_id=job_id)

-    async def evaluate(
+    async def evaluate_rows(
         self,
         input_rows: List[Dict[str, Any]],
-        candidate: EvalCandidate,
         scoring_functions: List[str],
+        eval_task_config: EvalTaskConfig,
     ) -> EvaluateResponse:
+        candidate = eval_task_config.eval_candidate
         if candidate.type == "agent":
             raise NotImplementedError(
                 "Evaluation with generation has not been implemented for agents"
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 459b58f22..fc15d48fc 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -151,4 +151,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.agents.fixtures",
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
+    "llama_stack.providers.tests.eval.fixtures",
 ]
diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py
new file mode 100644
index 000000000..064feb611
--- /dev/null
+++ b/llama_stack/providers/tests/eval/conftest.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+
+from ..datasetio.fixtures import DATASETIO_FIXTURES
+from ..inference.fixtures import INFERENCE_FIXTURES
+from ..scoring.fixtures import SCORING_FIXTURES
+from .fixtures import EVAL_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "eval": "meta_reference",
+            "scoring": "meta_reference",
+            "datasetio": "meta_reference",
+            "inference": "fireworks",
+        },
+        id="meta_reference_eval_fireworks_inference",
+        marks=pytest.mark.meta_reference_eval_fireworks_inference,
+    ),
+    pytest.param(
+        {
+            "eval": "meta_reference",
+            "scoring": "meta_reference",
+            "datasetio": "meta_reference",
+            "inference": "together",
+        },
+        id="meta_reference_eval_together_inference",
+        marks=pytest.mark.meta_reference_eval_together_inference,
+    ),
+]
+
+
+def pytest_configure(config):
+    for fixture_name in [
+        "meta_reference_eval_fireworks_inference",
+        "meta_reference_eval_together_inference",
+    ]:
+        config.addinivalue_line(
+            "markers",
+            f"{fixture_name}: marks tests as {fixture_name} specific",
+        )
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="Llama3.2-3B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "eval_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "eval": EVAL_FIXTURES,
+            "scoring": SCORING_FIXTURES,
+            "datasetio": DATASETIO_FIXTURES,
+            "inference": INFERENCE_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("eval_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py
new file mode 100644
index 000000000..22181f3b2
--- /dev/null
+++ b/llama_stack/providers/tests/eval/fixtures.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.distribution.datatypes import Api, Provider
+
+from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from ..conftest import ProviderFixture, remote_stack_fixture
+
+
+@pytest.fixture(scope="session")
+def eval_remote() -> ProviderFixture:
+    return remote_stack_fixture()
+
+
+@pytest.fixture(scope="session")
+def eval_meta_reference() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config={},
+            )
+        ],
+    )
+
+
+EVAL_FIXTURES = ["meta_reference", "remote"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def eval_stack(request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["datasetio", "eval", "scoring", "inference"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+
+    impls = await resolve_impls_for_test_v2(
+        [Api.eval, Api.datasetio, Api.inference, Api.scoring],
+        providers,
+        provider_data,
+    )
+
+    return (
+        impls[Api.eval],
+        impls[Api.eval_tasks],
+        impls[Api.scoring],
+        impls[Api.scoring_functions],
+        impls[Api.datasetio],
+        impls[Api.datasets],
+    )
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
new file mode 100644
index 000000000..cc14ccd1d
--- /dev/null
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+# How to run this test:
+#
+# pytest llama_stack/providers/tests/eval/test_eval.py
+#   -m "meta_reference"
+#   -v -s --tb=short --disable-warnings
+
+
+class Testeval:
+    @pytest.mark.asyncio
+    async def test_eval_tasks_list(self, eval_stack):
+        # NOTE: this needs you to ensure that you are starting from a clean state
+        # but so far we don't have an unregister API unfortunately, so be careful
+        _, eval_tasks_impl, _, _, _, _ = eval_stack
+        response = await eval_tasks_impl.list_eval_tasks()
+        assert isinstance(response, list)
+        print(response)
diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py
index 698d4a60a..ee578f9b3 100644
--- a/llama_stack/providers/tests/scoring/conftest.py
+++ b/llama_stack/providers/tests/scoring/conftest.py
@@ -21,12 +21,24 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         },
         id="meta_reference_scoring_fireworks_inference",
         marks=pytest.mark.meta_reference_scoring_fireworks_inference,
-    )
+    ),
+    pytest.param(
+        {
+            "scoring": "meta_reference",
+            "datasetio": "meta_reference",
+            "inference": "together",
+        },
+        id="meta_reference_scoring_together_inference",
+        marks=pytest.mark.meta_reference_scoring_together_inference,
+    ),
 ]


 def pytest_configure(config):
-    for fixture_name in ["meta_reference_scoring_fireworks_inference"]:
+    for fixture_name in [
+        "meta_reference_scoring_fireworks_inference",
+        "meta_reference_scoring_together_inference",
+    ]:
         config.addinivalue_line(
             "markers",
             f"{fixture_name}: marks tests as {fixture_name} specific",
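
Usage sketch (not part of the patch): a rough illustration of how the pieces wired up above -- the eval_tasks routing table, the EvalRouter, and the eval_stack fixture -- might be exercised end to end. The EvalTaskDefWithProvider field names, the AppEvalTaskConfig / ModelCandidate constructors, and the import paths are assumptions based on the eval and eval_tasks API modules referenced in this diff, not verified against this revision.

# Illustrative sketch only; names marked "assumed" below are not confirmed by this patch.
import pytest

from llama_stack.apis.eval.eval import AppEvalTaskConfig, ModelCandidate  # assumed exports
from llama_stack.apis.eval_tasks.eval_tasks import EvalTaskDefWithProvider  # assumed export
from llama_stack.apis.inference import SamplingParams  # assumed re-export


@pytest.mark.asyncio
async def test_register_and_run_eval_task(eval_stack):
    # Unpacking order matches the eval_stack fixture introduced in this patch.
    eval_impl, eval_tasks_impl, _, _, _, _ = eval_stack

    # Register an eval task through the new eval_tasks routing table.
    task_def = EvalTaskDefWithProvider(
        identifier="test::example-eval-task",            # assumed field names
        dataset_id="test_dataset_for_eval",              # assumes this dataset was already registered
        scoring_functions=["meta-reference::equality"],
        provider_id="meta-reference",
    )
    await eval_tasks_impl.register_eval_task(task_def)
    assert any(
        t.identifier == "test::example-eval-task"
        for t in await eval_tasks_impl.list_eval_tasks()
    )

    # Kick off a run; run_eval returns a Job handle per the new Eval API surface.
    job = await eval_impl.run_eval(
        eval_task_def=task_def,
        eval_task_config=AppEvalTaskConfig(              # any EvalTaskConfig carrying eval_candidate
            eval_candidate=ModelCandidate(
                model="Llama3.2-3B-Instruct",
                sampling_params=SamplingParams(),
            ),
        ),
    )
    response = await eval_impl.job_result(job.job_id)
    assert response.scores is not None                   # EvaluateResponse from the eval API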