fix router

This commit is contained in:
Xi Yan 2025-03-15 14:43:41 -07:00
parent a197101635
commit 081ec3131d
2 changed files with 97 additions and 31 deletions

View file

@ -13,6 +13,7 @@ from llama_stack.apis.common.content_types import (
URL, URL,
) )
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import ( from llama_stack.apis.eval import (
BenchmarkConfig, BenchmarkConfig,
Eval, Eval,
@ -537,21 +538,36 @@ class DatasetIORouter(DatasetIO):
logger.debug("DatasetIORouter.shutdown") logger.debug("DatasetIORouter.shutdown")
pass pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
rows_in_page: int, start_index: Optional[int] = None,
page_token: Optional[str] = None, limit: Optional[int] = None,
filter_condition: Optional[str] = None,
) -> IterrowsResponse: ) -> IterrowsResponse:
logger.debug( logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, rows_in_page={rows_in_page}", f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
) )
return await self.routing_table.get_provider_impl(dataset_id).iterrows( return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=rows_in_page, start_index=start_index,
page_token=page_token, limit=limit,
filter_condition=filter_condition,
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -12,7 +13,14 @@ from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ( from llama_stack.apis.scoring_functions import (
@ -97,7 +105,9 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry self.dist_registry = dist_registry
async def initialize(self) -> None: async def initialize(self) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
for obj in objs: for obj in objs:
if cls is None: if cls is None:
obj.provider_id = provider_id obj.provider_id = provider_id
@ -132,7 +142,9 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values(): for p in self.impls_by_provider_id.values():
await p.shutdown() await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def apiname_object(): def apiname_object():
if isinstance(self, ModelsRoutingTable): if isinstance(self, ModelsRoutingTable):
return ("Inference", "model") return ("Inference", "model")
@ -170,7 +182,9 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`") raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: async def get_object_by_identifier(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry # Get from disk registry
obj = await self.dist_registry.get(type, identifier) obj = await self.dist_registry.get(type, identifier)
if not obj: if not obj:
@ -180,9 +194,13 @@ class CommonRoutingTableImpl(RoutingTable):
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier) await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries # if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0: if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0] obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@ -237,7 +255,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model_type is None: if model_type is None:
model_type = ModelType.llm model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding: if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata") raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model( model = Model(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
@ -257,7 +277,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse: async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
async def get_shield(self, identifier: str) -> Optional[Shield]: async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier) return await self.get_object_by_identifier("shield", identifier)
@ -316,14 +338,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
) )
else: else:
raise ValueError("No provider available. Please configure a vector_io provider.") raise ValueError(
"No provider available. Please configure a vector_io provider."
)
model = await self.get_object_by_identifier("model", embedding_model) model = await self.get_object_by_identifier("model", embedding_model)
if model is None: if model is None:
raise ValueError(f"Model {embedding_model} not found") raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding: if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model") raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata: if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension") raise ValueError(
f"Model {embedding_model} does not have an embedding dimension"
)
vector_db_data = { vector_db_data = {
"identifier": vector_db_id, "identifier": vector_db_id,
"type": ResourceType.vector_db.value, "type": ResourceType.vector_db.value,
@ -345,22 +371,37 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse: async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id) return await self.get_object_by_identifier("dataset", dataset_id)
async def register_dataset( async def register_dataset(
self, self,
dataset_id: str, purpose: DatasetPurpose,
dataset_schema: Dict[str, ParamType], source: DataSource,
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
) -> None: dataset_id: Optional[str] = None,
if provider_dataset_id is None: ) -> Dataset:
provider_dataset_id = dataset_id if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if source.type == DatasetType.rows:
provider_id = "localfs"
elif source.type == DatasetType.uri:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if provider_id is None: if provider_id is None:
# If provider_id not specified, use the only provider if it supports this dataset # If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1: if len(self.impls_by_provider_id) == 1:
@ -371,15 +412,18 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
) )
if metadata is None: if metadata is None:
metadata = {} metadata = {}
dataset = Dataset( dataset = Dataset(
identifier=dataset_id, identifier=dataset_id,
provider_resource_id=provider_dataset_id, provider_resource_id=provider_dataset_id,
provider_id=provider_id, provider_id=provider_id,
dataset_schema=dataset_schema, purpose=purpose,
url=url, source=source,
metadata=metadata, metadata=metadata,
) )
await self.register_object(dataset) await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None: async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id) dataset = await self.get_dataset(dataset_id)
@ -390,7 +434,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id) return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -487,8 +533,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
args: Optional[Dict[str, Any]] = None, args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
tools = [] tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs: for tool_def in tool_defs:
tools.append( tools.append(