From 081ec3131d234b37b67bd56f88afca03f58d6400 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 14:43:41 -0700 Subject: [PATCH] fix router --- llama_stack/distribution/routers/routers.py | 30 ++++-- .../distribution/routers/routing_tables.py | 98 ++++++++++++++----- 2 files changed, 97 insertions(+), 31 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5acd945fe..879fc924b 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -13,6 +13,7 @@ from llama_stack.apis.common.content_types import ( URL, ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse +from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.eval import ( BenchmarkConfig, Eval, @@ -537,21 +538,36 @@ class DatasetIORouter(DatasetIO): logger.debug("DatasetIORouter.shutdown") pass + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: Optional[Dict[str, Any]] = None, + dataset_id: Optional[str] = None, + ) -> None: + logger.debug( + f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", + ) + await self.routing_table.register_dataset( + purpose=purpose, + source=source, + metadata=metadata, + dataset_id=dataset_id, + ) + async def iterrows( self, dataset_id: str, - rows_in_page: int, - page_token: Optional[str] = None, - filter_condition: Optional[str] = None, + start_index: Optional[int] = None, + limit: Optional[int] = None, ) -> IterrowsResponse: logger.debug( - f"DatasetIORouter.iterrows: {dataset_id}, rows_in_page={rows_in_page}", + f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", ) return await self.routing_table.get_provider_impl(dataset_id).iterrows( dataset_id=dataset_id, - rows_in_page=rows_in_page, - page_token=page_token, - filter_condition=filter_condition, + start_index=start_index, + limit=limit, ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 1be43ec8b..ec7abba90 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import logging +import uuid from typing import Any, Dict, List, Optional from pydantic import TypeAdapter @@ -12,7 +13,14 @@ from pydantic import TypeAdapter from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, +) from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( @@ -97,7 +105,9 @@ class CommonRoutingTableImpl(RoutingTable): self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: + async def add_objects( + objs: List[RoutableObjectWithProvider], provider_id: str, cls + ) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -132,7 +142,9 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: + def get_provider_impl( + self, routing_key: str, provider_id: Optional[str] = None + ) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -170,7 +182,9 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: @@ -180,9 +194,13 @@ class CommonRoutingTableImpl(RoutingTable): async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + await unregister_object_from_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + async def register_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -237,7 +255,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if model_type is None: model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = Model( identifier=model_id, provider_resource_id=provider_model_id, @@ -257,7 +277,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse( + data=await self.get_all_with_type(ResourceType.shield.value) + ) async def get_shield(self, identifier: str) -> Optional[Shield]: return await self.get_object_by_identifier("shield", identifier) @@ -316,14 +338,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: - raise ValueError("No provider available. Please configure a vector_io provider.") + raise ValueError( + "No provider available. Please configure a vector_io provider." + ) model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + raise ValueError( + f"Model {embedding_model} does not have an embedding dimension" + ) vector_db_data = { "identifier": vector_db_id, "type": ResourceType.vector_db.value, @@ -345,22 +371,37 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse( + data=await self.get_all_with_type(ResourceType.dataset.value) + ) async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: return await self.get_object_by_identifier("dataset", dataset_id) async def register_dataset( self, - dataset_id: str, - dataset_schema: Dict[str, ParamType], - url: URL, - provider_dataset_id: Optional[str] = None, - provider_id: Optional[str] = None, + purpose: DatasetPurpose, + source: DataSource, metadata: Optional[Dict[str, Any]] = None, - ) -> None: - if provider_dataset_id is None: - provider_dataset_id = dataset_id + dataset_id: Optional[str] = None, + ) -> Dataset: + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if source.type == DatasetType.rows: + provider_id = "localfs" + elif source.type == DatasetType.uri: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = "huggingface" + else: + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + if provider_id is None: # If provider_id not specified, use the only provider if it supports this dataset if len(self.impls_by_provider_id) == 1: @@ -371,15 +412,18 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): ) if metadata is None: metadata = {} + dataset = Dataset( identifier=dataset_id, provider_resource_id=provider_dataset_id, provider_id=provider_id, - dataset_schema=dataset_schema, - url=url, + purpose=purpose, + source=source, metadata=metadata, ) + await self.register_object(dataset) + return dataset async def unregister_dataset(self, dataset_id: str) -> None: dataset = await self.get_dataset(dataset_id) @@ -390,7 +434,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse( + data=await self.get_all_with_type(ResourceType.scoring_function.value) + ) async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: return await self.get_object_by_identifier("scoring_function", scoring_fn_id) @@ -487,8 +533,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( + toolgroup_id, mcp_endpoint + ) + tool_host = ( + ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ) for tool_def in tool_defs: tools.append(