diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 3a6140478..8f07b7d05 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.evaluation import Evaluation
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
@@ -35,6 +36,7 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     Api,
+    BenchmarksProtocolPrivate,
     DatasetsProtocolPrivate,
     InlineProviderSpec,
     ModelsProtocolPrivate,
@@ -71,6 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.post_training: PostTraining,
         Api.tool_groups: ToolGroups,
         Api.tool_runtime: ToolRuntime,
+        Api.evaluation: Evaluation,
     }
 
 
@@ -81,6 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
+        Api.evaluation: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
     }
 
 
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index a9a4f87c8..69b384bc4 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -46,6 +46,7 @@ async def get_routing_table_impl(
 async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
+        EvaluationRouter,
         InferenceRouter,
         SafetyRouter,
         ToolRuntimeRouter,
@@ -58,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
         "safety": SafetyRouter,
         "datasetio": DatasetIORouter,
         "tool_runtime": ToolRuntimeRouter,
+        "evaluation": EvaluationRouter,
     }
     api_to_deps = {
         "inference": {"telemetry": Api.telemetry},
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 6c77d09e8..17ef1626f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -7,13 +7,21 @@
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
+from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
     InterleavedContentItem,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
-from llama_stack.apis.datasets import DatasetPurpose, DataSource
+from llama_stack.apis.datasets import Dataset, DatasetPurpose, DataSource
+from llama_stack.apis.evaluation import (
+    Evaluation,
+    EvaluationCandidate,
+    EvaluationJob,
+    EvaluationResponse,
+    EvaluationTask,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -474,11 +482,11 @@ class DatasetIORouter(DatasetIO):
         source: DataSource,
         metadata: Optional[Dict[str, Any]] = None,
         dataset_id: Optional[str] = None,
-    ) -> None:
+    ) -> Dataset:
         logger.debug(
             f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
         )
-        await self.routing_table.register_dataset(
+        return await self.routing_table.register_dataset(
             purpose=purpose,
             source=source,
             metadata=metadata,
@@ -573,3 +581,57 @@ class ToolRuntimeRouter(ToolRuntime):
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+
+
+class EvaluationRouter(Evaluation):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing EvaluationRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("EvaluationRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("EvaluationRouter.shutdown")
+        pass
+
+    async def register_benchmark(
+        self,
+        dataset_id: str,
+        grader_ids: List[str],
+        benchmark_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Benchmark:
+        logger.debug(
+            f"EvaluationRouter.register_benchmark: {benchmark_id=} {dataset_id=} {grader_ids=} {metadata=}",
+        )
+        return await self.routing_table.register_benchmark(
+            benchmark_id=benchmark_id,
+            dataset_id=dataset_id,
+            grader_ids=grader_ids,
+            metadata=metadata,
+        )
+
+    async def run(
+        self,
+        task: EvaluationTask,
+        candidate: EvaluationCandidate,
+    ) -> EvaluationJob:
+        raise NotImplementedError("Run is not implemented yet")
+
+    async def run_sync(
+        self,
+        task: EvaluationTask,
+        candidate: EvaluationCandidate,
+    ) -> EvaluationResponse:
+        raise NotImplementedError("Run sync is not implemented yet")
+
+    async def grade(self, task: EvaluationTask) -> EvaluationJob:
+        raise NotImplementedError("Grade is not implemented yet")
+
+    async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
+        raise NotImplementedError("Grade sync is not implemented yet")
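
Reviewer note: a minimal, hypothetical sketch of how the surface added by this diff might be exercised once the stack is wired up. The identifiers "my-eval-dataset" and "my-grader" are invented for illustration; per this change only register_benchmark is routed through, while run/run_sync/grade/grade_sync still raise NotImplementedError.

# Hypothetical usage sketch (not part of this diff). Assumes an EvaluationRouter
# was constructed by get_auto_router_impl(Api.evaluation, routing_table, deps),
# as the resolver/router registrations above imply.
from llama_stack.distribution.routers.routers import EvaluationRouter


async def register_example_benchmark(evaluation: EvaluationRouter) -> None:
    # register_benchmark is the only path implemented in this diff; it simply
    # forwards to the routing table's register_benchmark.
    benchmark = await evaluation.register_benchmark(
        dataset_id="my-eval-dataset",  # hypothetical dataset id
        grader_ids=["my-grader"],      # hypothetical grader id
        metadata={"description": "example benchmark"},
    )
    print(benchmark)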