fix precommit

This commit is contained in:
Xi Yan 2025-03-23 16:42:50 -07:00
parent 81bc051411
commit 97e7717c9b
3 changed files with 71 additions and 3 deletions

View file

@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -35,6 +36,7 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
InlineProviderSpec,
ModelsProtocolPrivate,
@ -71,6 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.evaluation: Evaluation,
}
@ -81,6 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.evaluation: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
}

View file

@ -46,6 +46,7 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvaluationRouter,
InferenceRouter,
SafetyRouter,
ToolRuntimeRouter,
@ -58,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"tool_runtime": ToolRuntimeRouter,
"evaluation": EvaluationRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},

View file

@ -7,13 +7,21 @@
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.datasets import Dataset, DatasetPurpose, DataSource
from llama_stack.apis.evaluation import (
Evaluation,
EvaluationCandidate,
EvaluationJob,
EvaluationResponse,
EvaluationTask,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@ -474,11 +482,11 @@ class DatasetIORouter(DatasetIO):
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
) -> Dataset:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
return await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
@ -573,3 +581,57 @@ class ToolRuntimeRouter(ToolRuntime):
) -> List[ToolDef]:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
class EvaluationRouter(Evaluation):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvaluationRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvaluationRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvaluationRouter.shutdown")
pass
async def register_benchmark(
self,
dataset_id: str,
grader_ids: List[str],
benchmark_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Benchmark:
logger.debug(
f"EvaluationRouter.register_benchmark: {benchmark_id=} {dataset_id=} {grader_ids=} {metadata=}",
)
return await self.routing_table.register_benchmark(
benchmark_id=benchmark_id,
dataset_id=dataset_id,
grader_ids=grader_ids,
metadata=metadata,
)
async def run(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationJob:
raise NotImplementedError("Run is not implemented yet")
async def run_sync(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationResponse:
raise NotImplementedError("Run sync is not implemented yet")
async def grade(self, task: EvaluationTask) -> EvaluationJob:
raise NotImplementedError("Grade is not implemented yet")
async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
raise NotImplementedError("Grade sync is not implemented yet")