From 298721c238246d6dd06439ec7d4006a466dc8bf8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 24 May 2025 23:15:05 -0700 Subject: [PATCH] chore: split routing_tables into individual files (#2259) --- llama_stack/distribution/routers/__init__.py | 18 +- .../distribution/routers/routing_tables.py | 634 ------------------ .../distribution/routing_tables/__init__.py | 5 + .../distribution/routing_tables/benchmarks.py | 58 ++ .../distribution/routing_tables/common.py | 218 ++++++ .../distribution/routing_tables/datasets.py | 93 +++ .../distribution/routing_tables/models.py | 82 +++ .../routing_tables/scoring_functions.py | 62 ++ .../distribution/routing_tables/shields.py | 57 ++ .../distribution/routing_tables/toolgroups.py | 98 +++ .../distribution/routing_tables/vector_dbs.py | 74 ++ pyproject.toml | 1 + .../routers/test_routing_tables.py | 16 +- tests/unit/server/test_access_control.py | 12 +- tests/unit/server/test_resolver.py | 2 +- 15 files changed, 770 insertions(+), 660 deletions(-) delete mode 100644 llama_stack/distribution/routers/routing_tables.py create mode 100644 llama_stack/distribution/routing_tables/__init__.py create mode 100644 llama_stack/distribution/routing_tables/benchmarks.py create mode 100644 llama_stack/distribution/routing_tables/common.py create mode 100644 llama_stack/distribution/routing_tables/datasets.py create mode 100644 llama_stack/distribution/routing_tables/models.py create mode 100644 llama_stack/distribution/routing_tables/scoring_functions.py create mode 100644 llama_stack/distribution/routing_tables/shields.py create mode 100644 llama_stack/distribution/routing_tables/toolgroups.py create mode 100644 llama_stack/distribution/routing_tables/vector_dbs.py diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index c1136a8d4..1358d5812 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -12,16 +12,6 @@ from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from .routing_tables import ( - BenchmarksRoutingTable, - DatasetsRoutingTable, - ModelsRoutingTable, - ScoringFunctionsRoutingTable, - ShieldsRoutingTable, - ToolGroupsRoutingTable, - VectorDBsRoutingTable, -) - async def get_routing_table_impl( api: Api, @@ -29,6 +19,14 @@ async def get_routing_table_impl( _deps, dist_registry: DistributionRegistry, ) -> Any: + from ..routing_tables.benchmarks import BenchmarksRoutingTable + from ..routing_tables.datasets import DatasetsRoutingTable + from ..routing_tables.models import ModelsRoutingTable + from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable + from ..routing_tables.shields import ShieldsRoutingTable + from ..routing_tables.toolgroups import ToolGroupsRoutingTable + from ..routing_tables.vector_dbs import VectorDBsRoutingTable + api_to_tables = { "vector_dbs": VectorDBsRoutingTable, "models": ModelsRoutingTable, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py deleted file mode 100644 index c04562197..000000000 --- a/llama_stack/distribution/routers/routing_tables.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import logging -import time -import uuid -from typing import Any - -from pydantic import TypeAdapter - -from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import ( - Dataset, - DatasetPurpose, - Datasets, - DatasetType, - DataSource, - ListDatasetsResponse, - RowsDataSource, - URIDataSource, -) -from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.scoring_functions import ( - ListScoringFunctionsResponse, - ScoringFn, - ScoringFnParams, - ScoringFunctions, -) -from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields -from llama_stack.apis.tools import ( - ListToolGroupsResponse, - ListToolsResponse, - Tool, - ToolGroup, - ToolGroups, - ToolHost, -) -from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs -from llama_stack.distribution.access_control import check_access -from llama_stack.distribution.datatypes import ( - AccessAttributes, - BenchmarkWithACL, - DatasetWithACL, - ModelWithACL, - RoutableObject, - RoutableObjectWithProvider, - RoutedProtocol, - ScoringFnWithACL, - ShieldWithACL, - ToolGroupWithACL, - ToolWithACL, - VectorDBWithACL, -) -from llama_stack.distribution.request_headers import get_auth_attributes -from llama_stack.distribution.store import DistributionRegistry -from llama_stack.providers.datatypes import Api, RoutingTable - -logger = logging.getLogger(__name__) - - -def get_impl_api(p: Any) -> Api: - return p.__provider_spec__.api - - -# TODO: this should return the registered object for all APIs -async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: - api = get_impl_api(p) - - assert obj.provider_id != "remote", "Remote provider should not be registered" - - if api == Api.inference: - return await p.register_model(obj) - elif api == Api.safety: - return await p.register_shield(obj) - elif api == Api.vector_io: - return await p.register_vector_db(obj) - elif api == Api.datasetio: - return await p.register_dataset(obj) - elif api == Api.scoring: - return await p.register_scoring_function(obj) - elif api == Api.eval: - return await p.register_benchmark(obj) - elif api == Api.tool_runtime: - return await p.register_tool(obj) - else: - raise ValueError(f"Unknown API {api} for registering object with provider") - - -async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: - api = get_impl_api(p) - if api == Api.vector_io: - return await p.unregister_vector_db(obj.identifier) - elif api == Api.inference: - return await p.unregister_model(obj.identifier) - elif api == Api.datasetio: - return await p.unregister_dataset(obj.identifier) - elif api == Api.tool_runtime: - return await p.unregister_tool(obj.identifier) - else: - raise ValueError(f"Unregister not supported for {api}") - - -Registry = dict[str, list[RoutableObjectWithProvider]] - - -class CommonRoutingTableImpl(RoutingTable): - def __init__( - self, - impls_by_provider_id: dict[str, RoutedProtocol], - dist_registry: DistributionRegistry, - ) -> None: - self.impls_by_provider_id = impls_by_provider_id - self.dist_registry = dist_registry - - async def initialize(self) -> None: - async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: - for obj in objs: - if cls is None: - obj.provider_id = provider_id - else: - # Create a copy of the model data and explicitly set provider_id - model_data = obj.model_dump() - model_data["provider_id"] = provider_id - obj = cls(**model_data) - await self.dist_registry.register(obj) - - # Register all objects from providers - for pid, p in self.impls_by_provider_id.items(): - api = get_impl_api(p) - if api == Api.inference: - p.model_store = self - elif api == Api.safety: - p.shield_store = self - elif api == Api.vector_io: - p.vector_db_store = self - elif api == Api.datasetio: - p.dataset_store = self - elif api == Api.scoring: - p.scoring_function_store = self - scoring_functions = await p.list_scoring_functions() - await add_objects(scoring_functions, pid, ScoringFn) - elif api == Api.eval: - p.benchmark_store = self - elif api == Api.tool_runtime: - p.tool_store = self - - async def shutdown(self) -> None: - for p in self.impls_by_provider_id.values(): - await p.shutdown() - - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: - def apiname_object(): - if isinstance(self, ModelsRoutingTable): - return ("Inference", "model") - elif isinstance(self, ShieldsRoutingTable): - return ("Safety", "shield") - elif isinstance(self, VectorDBsRoutingTable): - return ("VectorIO", "vector_db") - elif isinstance(self, DatasetsRoutingTable): - return ("DatasetIO", "dataset") - elif isinstance(self, ScoringFunctionsRoutingTable): - return ("Scoring", "scoring_function") - elif isinstance(self, BenchmarksRoutingTable): - return ("Eval", "benchmark") - elif isinstance(self, ToolGroupsRoutingTable): - return ("Tools", "tool") - else: - raise ValueError("Unknown routing table type") - - apiname, objtype = apiname_object() - - # Get objects from disk registry - obj = self.dist_registry.get_cached(objtype, routing_key) - if not obj: - provider_ids = list(self.impls_by_provider_id.keys()) - if len(provider_ids) > 1: - provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" - else: - provider_ids_str = f"provider: `{provider_ids[0]}`" - raise ValueError( - f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." - ) - - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] - - raise ValueError(f"Provider not found for `{routing_key}`") - - async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: - # Get from disk registry - obj = await self.dist_registry.get(type, identifier) - if not obj: - return None - - # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") - return None - - return obj - - async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: - await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) - - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: - # if provider_id is not specified, pick an arbitrary one from existing entries - if not obj.provider_id and len(self.impls_by_provider_id) > 0: - obj.provider_id = list(self.impls_by_provider_id.keys())[0] - - if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") - - p = self.impls_by_provider_id[obj.provider_id] - - # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: - creator_attributes = get_auth_attributes() - if creator_attributes: - obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") - - registered_obj = await register_object_with_provider(obj, p) - # TODO: This needs to be fixed for all APIs once they return the registered object - if obj.type == ResourceType.model.value: - await self.dist_registry.register(registered_obj) - return registered_obj - - else: - await self.dist_registry.register(obj) - return obj - - async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: - objs = await self.dist_registry.get_all() - filtered_objs = [obj for obj in objs if obj.type == type] - - # Apply attribute-based access control filtering - if filtered_objs: - filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) - ] - - return filtered_objs - - -class ModelsRoutingTable(CommonRoutingTableImpl, Models): - async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) - - async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") - openai_models = [ - OpenAIModel( - id=model.identifier, - object="model", - created=int(time.time()), - owned_by="llama_stack", - ) - for model in models - ] - return OpenAIListModelsResponse(data=openai_models) - - async def get_model(self, model_id: str) -> Model: - model = await self.get_object_by_identifier("model", model_id) - if model is None: - raise ValueError(f"Model '{model_id}' not found") - return model - - async def register_model( - self, - model_id: str, - provider_model_id: str | None = None, - provider_id: str | None = None, - metadata: dict[str, Any] | None = None, - model_type: ModelType | None = None, - ) -> Model: - if provider_model_id is None: - provider_model_id = model_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this model - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" - ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm - if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = ModelWithACL( - identifier=model_id, - provider_resource_id=provider_model_id, - provider_id=provider_id, - metadata=metadata, - model_type=model_type, - ) - registered_model = await self.register_object(model) - return registered_model - - async def unregister_model(self, model_id: str) -> None: - existing_model = await self.get_model(model_id) - if existing_model is None: - raise ValueError(f"Model {model_id} not found") - await self.unregister_object(existing_model) - - -class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): - async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) - - async def get_shield(self, identifier: str) -> Shield: - shield = await self.get_object_by_identifier("shield", identifier) - if shield is None: - raise ValueError(f"Shield '{identifier}' not found") - return shield - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - if provider_shield_id is None: - provider_shield_id = shield_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this shield type - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - if params is None: - params = {} - shield = ShieldWithACL( - identifier=shield_id, - provider_resource_id=provider_shield_id, - provider_id=provider_id, - params=params, - ) - await self.register_object(shield) - return shield - - -class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): - async def list_vector_dbs(self) -> ListVectorDBsResponse: - return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) - - async def get_vector_db(self, vector_db_id: str) -> VectorDB: - vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) - if vector_db is None: - raise ValueError(f"Vector DB '{vector_db_id}' not found") - return vector_db - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorDB: - if provider_vector_db_id is None: - provider_vector_db_id = vector_db_id - if provider_id is None: - if len(self.impls_by_provider_id) > 0: - provider_id = list(self.impls_by_provider_id.keys())[0] - if len(self.impls_by_provider_id) > 1: - logger.warning( - f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." - ) - else: - raise ValueError("No provider available. Please configure a vector_io provider.") - model = await self.get_object_by_identifier("model", embedding_model) - if model is None: - raise ValueError(f"Model {embedding_model} not found") - if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") - if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") - vector_db_data = { - "identifier": vector_db_id, - "type": ResourceType.vector_db.value, - "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, - "embedding_model": embedding_model, - "embedding_dimension": model.metadata["embedding_dimension"], - } - vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) - await self.register_object(vector_db) - return vector_db - - async def unregister_vector_db(self, vector_db_id: str) -> None: - existing_vector_db = await self.get_vector_db(vector_db_id) - if existing_vector_db is None: - raise ValueError(f"Vector DB {vector_db_id} not found") - await self.unregister_object(existing_vector_db) - - -class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): - async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) - - async def get_dataset(self, dataset_id: str) -> Dataset: - dataset = await self.get_object_by_identifier("dataset", dataset_id) - if dataset is None: - raise ValueError(f"Dataset '{dataset_id}' not found") - return dataset - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> Dataset: - if isinstance(source, dict): - if source["type"] == "uri": - source = URIDataSource.parse_obj(source) - elif source["type"] == "rows": - source = RowsDataSource.parse_obj(source) - - if not dataset_id: - dataset_id = f"dataset-{str(uuid.uuid4())}" - - provider_dataset_id = dataset_id - - # infer provider from source - if metadata: - if metadata.get("provider_id"): - provider_id = metadata.get("provider_id") # pass through from nvidia datasetio - elif source.type == DatasetType.rows.value: - provider_id = "localfs" - elif source.type == DatasetType.uri.value: - # infer provider from uri - if source.uri.startswith("huggingface"): - provider_id = "huggingface" - else: - provider_id = "localfs" - else: - raise ValueError(f"Unknown data source type: {source.type}") - - if metadata is None: - metadata = {} - - dataset = DatasetWithACL( - identifier=dataset_id, - provider_resource_id=provider_dataset_id, - provider_id=provider_id, - purpose=purpose, - source=source, - metadata=metadata, - ) - - await self.register_object(dataset) - return dataset - - async def unregister_dataset(self, dataset_id: str) -> None: - dataset = await self.get_dataset(dataset_id) - if dataset is None: - raise ValueError(f"Dataset {dataset_id} not found") - await self.unregister_object(dataset) - - -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): - async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) - - async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) - if scoring_fn is None: - raise ValueError(f"Scoring function '{scoring_fn_id}' not found") - return scoring_fn - - async def register_scoring_function( - self, - scoring_fn_id: str, - description: str, - return_type: ParamType, - provider_scoring_fn_id: str | None = None, - provider_id: str | None = None, - params: ScoringFnParams | None = None, - ) -> None: - if provider_scoring_fn_id is None: - provider_scoring_fn_id = scoring_fn_id - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - scoring_fn = ScoringFnWithACL( - identifier=scoring_fn_id, - description=description, - return_type=return_type, - provider_resource_id=provider_scoring_fn_id, - provider_id=provider_id, - params=params, - ) - scoring_fn.provider_id = provider_id - await self.register_object(scoring_fn) - - -class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): - async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - - async def get_benchmark(self, benchmark_id: str) -> Benchmark: - benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) - if benchmark is None: - raise ValueError(f"Benchmark '{benchmark_id}' not found") - return benchmark - - async def register_benchmark( - self, - benchmark_id: str, - dataset_id: str, - scoring_functions: list[str], - metadata: dict[str, Any] | None = None, - provider_benchmark_id: str | None = None, - provider_id: str | None = None, - ) -> None: - if metadata is None: - metadata = {} - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - if provider_benchmark_id is None: - provider_benchmark_id = benchmark_id - benchmark = BenchmarkWithACL( - identifier=benchmark_id, - dataset_id=dataset_id, - scoring_functions=scoring_functions, - metadata=metadata, - provider_id=provider_id, - provider_resource_id=provider_benchmark_id, - ) - await self.register_object(benchmark) - - -class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): - async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - tools = await self.get_all_with_type("tool") - if toolgroup_id: - tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] - return ListToolsResponse(data=tools) - - async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) - - async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: - tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group '{toolgroup_id}' not found") - return tool_group - - async def get_tool(self, tool_name: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_name) - - async def register_tool_group( - self, - toolgroup_id: str, - provider_id: str, - mcp_endpoint: URL | None = None, - args: dict[str, Any] | None = None, - ) -> None: - tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - - for tool_def in tool_defs.data: - tools.append( - ToolWithACL( - identifier=tool_def.name, - toolgroup_id=toolgroup_id, - description=tool_def.description or "", - parameters=tool_def.parameters or [], - provider_id=provider_id, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - for tool in tools: - existing_tool = await self.get_tool(tool.identifier) - # Compare existing and new object if one exists - if existing_tool: - existing_dict = existing_tool.model_dump() - new_dict = tool.model_dump() - - if existing_dict != new_dict: - raise ValueError( - f"Object {tool.identifier} already exists in registry. Please use a different identifier." - ) - await self.register_object(tool) - - await self.dist_registry.register( - ToolGroupWithACL( - identifier=toolgroup_id, - provider_id=provider_id, - provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, - args=args, - ) - ) - - async def unregister_toolgroup(self, toolgroup_id: str) -> None: - tool_group = await self.get_tool_group(toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group {toolgroup_id} not found") - tools = await self.list_tools(toolgroup_id) - for tool in getattr(tools, "data", []): - await self.unregister_object(tool) - await self.unregister_object(tool_group) - - async def shutdown(self) -> None: - pass diff --git a/llama_stack/distribution/routing_tables/__init__.py b/llama_stack/distribution/routing_tables/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/routing_tables/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/distribution/routing_tables/benchmarks.py new file mode 100644 index 000000000..589a00c02 --- /dev/null +++ b/llama_stack/distribution/routing_tables/benchmarks.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse +from llama_stack.distribution.datatypes import ( + BenchmarkWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): + async def list_benchmarks(self) -> ListBenchmarksResponse: + return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) + + async def get_benchmark(self, benchmark_id: str) -> Benchmark: + benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) + if benchmark is None: + raise ValueError(f"Benchmark '{benchmark_id}' not found") + return benchmark + + async def register_benchmark( + self, + benchmark_id: str, + dataset_id: str, + scoring_functions: list[str], + metadata: dict[str, Any] | None = None, + provider_benchmark_id: str | None = None, + provider_id: str | None = None, + ) -> None: + if metadata is None: + metadata = {} + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if provider_benchmark_id is None: + provider_benchmark_id = benchmark_id + benchmark = BenchmarkWithACL( + identifier=benchmark_id, + dataset_id=dataset_id, + scoring_functions=scoring_functions, + metadata=metadata, + provider_id=provider_id, + provider_resource_id=provider_benchmark_id, + ) + await self.register_object(benchmark) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py new file mode 100644 index 000000000..95a92a5ba --- /dev/null +++ b/llama_stack/distribution/routing_tables/common.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.datatypes import ( + AccessAttributes, + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, +) +from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.store import DistributionRegistry +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import Api, RoutingTable + +logger = get_logger(name=__name__, category="core") + + +def get_impl_api(p: Any) -> Api: + return p.__provider_spec__.api + + +# TODO: this should return the registered object for all APIs +async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: + api = get_impl_api(p) + + assert obj.provider_id != "remote", "Remote provider should not be registered" + + if api == Api.inference: + return await p.register_model(obj) + elif api == Api.safety: + return await p.register_shield(obj) + elif api == Api.vector_io: + return await p.register_vector_db(obj) + elif api == Api.datasetio: + return await p.register_dataset(obj) + elif api == Api.scoring: + return await p.register_scoring_function(obj) + elif api == Api.eval: + return await p.register_benchmark(obj) + elif api == Api.tool_runtime: + return await p.register_tool(obj) + else: + raise ValueError(f"Unknown API {api} for registering object with provider") + + +async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: + api = get_impl_api(p) + if api == Api.vector_io: + return await p.unregister_vector_db(obj.identifier) + elif api == Api.inference: + return await p.unregister_model(obj.identifier) + elif api == Api.datasetio: + return await p.unregister_dataset(obj.identifier) + elif api == Api.tool_runtime: + return await p.unregister_tool(obj.identifier) + else: + raise ValueError(f"Unregister not supported for {api}") + + +Registry = dict[str, list[RoutableObjectWithProvider]] + + +class CommonRoutingTableImpl(RoutingTable): + def __init__( + self, + impls_by_provider_id: dict[str, RoutedProtocol], + dist_registry: DistributionRegistry, + ) -> None: + self.impls_by_provider_id = impls_by_provider_id + self.dist_registry = dist_registry + + async def initialize(self) -> None: + async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: + for obj in objs: + if cls is None: + obj.provider_id = provider_id + else: + # Create a copy of the model data and explicitly set provider_id + model_data = obj.model_dump() + model_data["provider_id"] = provider_id + obj = cls(**model_data) + await self.dist_registry.register(obj) + + # Register all objects from providers + for pid, p in self.impls_by_provider_id.items(): + api = get_impl_api(p) + if api == Api.inference: + p.model_store = self + elif api == Api.safety: + p.shield_store = self + elif api == Api.vector_io: + p.vector_db_store = self + elif api == Api.datasetio: + p.dataset_store = self + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + await add_objects(scoring_functions, pid, ScoringFn) + elif api == Api.eval: + p.benchmark_store = self + elif api == Api.tool_runtime: + p.tool_store = self + + async def shutdown(self) -> None: + for p in self.impls_by_provider_id.values(): + await p.shutdown() + + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + from .benchmarks import BenchmarksRoutingTable + from .datasets import DatasetsRoutingTable + from .models import ModelsRoutingTable + from .scoring_functions import ScoringFunctionsRoutingTable + from .shields import ShieldsRoutingTable + from .toolgroups import ToolGroupsRoutingTable + from .vector_dbs import VectorDBsRoutingTable + + def apiname_object(): + if isinstance(self, ModelsRoutingTable): + return ("Inference", "model") + elif isinstance(self, ShieldsRoutingTable): + return ("Safety", "shield") + elif isinstance(self, VectorDBsRoutingTable): + return ("VectorIO", "vector_db") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") + elif isinstance(self, BenchmarksRoutingTable): + return ("Eval", "benchmark") + elif isinstance(self, ToolGroupsRoutingTable): + return ("Tools", "tool") + else: + raise ValueError("Unknown routing table type") + + apiname, objtype = apiname_object() + + # Get objects from disk registry + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: + provider_ids = list(self.impls_by_provider_id.keys()) + if len(provider_ids) > 1: + provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" + else: + provider_ids_str = f"provider: `{provider_ids[0]}`" + raise ValueError( + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." + ) + + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] + + raise ValueError(f"Provider not found for `{routing_key}`") + + async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + # Get from disk registry + obj = await self.dist_registry.get(type, identifier) + if not obj: + return None + + # Check if user has permission to access this object + if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): + logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + return None + + return obj + + async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + + async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + # if provider_id is not specified, pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") + + p = self.impls_by_provider_id[obj.provider_id] + + # If object supports access control but no attributes set, use creator's attributes + if not obj.access_attributes: + creator_attributes = get_auth_attributes() + if creator_attributes: + obj.access_attributes = AccessAttributes(**creator_attributes) + logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + + registered_obj = await register_object_with_provider(obj, p) + # TODO: This needs to be fixed for all APIs once they return the registered object + if obj.type == ResourceType.model.value: + await self.dist_registry.register(registered_obj) + return registered_obj + + else: + await self.dist_registry.register(obj) + return obj + + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + filtered_objs = [obj for obj in objs if obj.type == type] + + # Apply attribute-based access control filtering + if filtered_objs: + filtered_objs = [ + obj + for obj in filtered_objs + if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + ] + + return filtered_objs diff --git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/distribution/routing_tables/datasets.py new file mode 100644 index 000000000..4401ad47e --- /dev/null +++ b/llama_stack/distribution/routing_tables/datasets.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import uuid +from typing import Any + +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, + RowsDataSource, + URIDataSource, +) +from llama_stack.apis.resource import ResourceType +from llama_stack.distribution.datatypes import ( + DatasetWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): + async def list_datasets(self) -> ListDatasetsResponse: + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + + async def get_dataset(self, dataset_id: str) -> Dataset: + dataset = await self.get_object_by_identifier("dataset", dataset_id) + if dataset is None: + raise ValueError(f"Dataset '{dataset_id}' not found") + return dataset + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> Dataset: + if isinstance(source, dict): + if source["type"] == "uri": + source = URIDataSource.parse_obj(source) + elif source["type"] == "rows": + source = RowsDataSource.parse_obj(source) + + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if metadata: + if metadata.get("provider_id"): + provider_id = metadata.get("provider_id") # pass through from nvidia datasetio + elif source.type == DatasetType.rows.value: + provider_id = "localfs" + elif source.type == DatasetType.uri.value: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = "huggingface" + else: + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + + if metadata is None: + metadata = {} + + dataset = DatasetWithACL( + identifier=dataset_id, + provider_resource_id=provider_dataset_id, + provider_id=provider_id, + purpose=purpose, + source=source, + metadata=metadata, + ) + + await self.register_object(dataset) + return dataset + + async def unregister_dataset(self, dataset_id: str) -> None: + dataset = await self.get_dataset(dataset_id) + if dataset is None: + raise ValueError(f"Dataset {dataset_id} not found") + await self.unregister_object(dataset) diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py new file mode 100644 index 000000000..7216d9935 --- /dev/null +++ b/llama_stack/distribution/routing_tables/models.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from typing import Any + +from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel +from llama_stack.distribution.datatypes import ( + ModelWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ModelsRoutingTable(CommonRoutingTableImpl, Models): + async def list_models(self) -> ListModelsResponse: + return ListModelsResponse(data=await self.get_all_with_type("model")) + + async def openai_list_models(self) -> OpenAIListModelsResponse: + models = await self.get_all_with_type("model") + openai_models = [ + OpenAIModel( + id=model.identifier, + object="model", + created=int(time.time()), + owned_by="llama_stack", + ) + for model in models + ] + return OpenAIListModelsResponse(data=openai_models) + + async def get_model(self, model_id: str) -> Model: + model = await self.get_object_by_identifier("model", model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + return model + + async def register_model( + self, + model_id: str, + provider_model_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, Any] | None = None, + model_type: ModelType | None = None, + ) -> Model: + if provider_model_id is None: + provider_model_id = model_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this model + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + ) + if metadata is None: + metadata = {} + if model_type is None: + model_type = ModelType.llm + if "embedding_dimension" not in metadata and model_type == ModelType.embedding: + raise ValueError("Embedding model must have an embedding dimension in its metadata") + model = ModelWithACL( + identifier=model_id, + provider_resource_id=provider_model_id, + provider_id=provider_id, + metadata=metadata, + model_type=model_type, + ) + registered_model = await self.register_object(model) + return registered_model + + async def unregister_model(self, model_id: str) -> None: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + await self.unregister_object(existing_model) diff --git a/llama_stack/distribution/routing_tables/scoring_functions.py b/llama_stack/distribution/routing_tables/scoring_functions.py new file mode 100644 index 000000000..d85f64b57 --- /dev/null +++ b/llama_stack/distribution/routing_tables/scoring_functions.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ( + ListScoringFunctionsResponse, + ScoringFn, + ScoringFnParams, + ScoringFunctions, +) +from llama_stack.distribution.datatypes import ( + ScoringFnWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): + async def list_scoring_functions(self) -> ListScoringFunctionsResponse: + return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + + async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: + scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) + if scoring_fn is None: + raise ValueError(f"Scoring function '{scoring_fn_id}' not found") + return scoring_fn + + async def register_scoring_function( + self, + scoring_fn_id: str, + description: str, + return_type: ParamType, + provider_scoring_fn_id: str | None = None, + provider_id: str | None = None, + params: ScoringFnParams | None = None, + ) -> None: + if provider_scoring_fn_id is None: + provider_scoring_fn_id = scoring_fn_id + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + scoring_fn = ScoringFnWithACL( + identifier=scoring_fn_id, + description=description, + return_type=return_type, + provider_resource_id=provider_scoring_fn_id, + provider_id=provider_id, + params=params, + ) + scoring_fn.provider_id = provider_id + await self.register_object(scoring_fn) diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py new file mode 100644 index 000000000..7f62596c9 --- /dev/null +++ b/llama_stack/distribution/routing_tables/shields.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields +from llama_stack.distribution.datatypes import ( + ShieldWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): + async def list_shields(self) -> ListShieldsResponse: + return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + + async def get_shield(self, identifier: str) -> Shield: + shield = await self.get_object_by_identifier("shield", identifier) + if shield is None: + raise ValueError(f"Shield '{identifier}' not found") + return shield + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + if provider_shield_id is None: + provider_shield_id = shield_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this shield type + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if params is None: + params = {} + shield = ShieldWithACL( + identifier=shield_id, + provider_resource_id=provider_shield_id, + provider_id=provider_id, + params=params, + ) + await self.register_object(shield) + return shield diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py new file mode 100644 index 000000000..cb73dc7c2 --- /dev/null +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost +from llama_stack.distribution.datatypes import ( + ToolGroupWithACL, + ToolWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: + tools = await self.get_all_with_type("tool") + if toolgroup_id: + tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] + return ListToolsResponse(data=tools) + + async def list_tool_groups(self) -> ListToolGroupsResponse: + return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) + + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: + tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group '{toolgroup_id}' not found") + return tool_group + + async def get_tool(self, tool_name: str) -> Tool: + return await self.get_object_by_identifier("tool", tool_name) + + async def register_tool_group( + self, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: URL | None = None, + args: dict[str, Any] | None = None, + ) -> None: + tools = [] + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) + tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + + for tool_def in tool_defs.data: + tools.append( + ToolWithACL( + identifier=tool_def.name, + toolgroup_id=toolgroup_id, + description=tool_def.description or "", + parameters=tool_def.parameters or [], + provider_id=provider_id, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + tool_host=tool_host, + ) + ) + for tool in tools: + existing_tool = await self.get_tool(tool.identifier) + # Compare existing and new object if one exists + if existing_tool: + existing_dict = existing_tool.model_dump() + new_dict = tool.model_dump() + + if existing_dict != new_dict: + raise ValueError( + f"Object {tool.identifier} already exists in registry. Please use a different identifier." + ) + await self.register_object(tool) + + await self.dist_registry.register( + ToolGroupWithACL( + identifier=toolgroup_id, + provider_id=provider_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, + ) + ) + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + tool_group = await self.get_tool_group(toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group {toolgroup_id} not found") + tools = await self.list_tools(toolgroup_id) + for tool in getattr(tools, "data", []): + await self.unregister_object(tool) + await self.unregister_object(tool_group) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py new file mode 100644 index 000000000..dc6c0d0ef --- /dev/null +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import TypeAdapter + +from llama_stack.apis.models import ModelType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.distribution.datatypes import ( + VectorDBWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): + async def list_vector_dbs(self) -> ListVectorDBsResponse: + return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) + + async def get_vector_db(self, vector_db_id: str) -> VectorDB: + vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) + if vector_db is None: + raise ValueError(f"Vector DB '{vector_db_id}' not found") + return vector_db + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> VectorDB: + if provider_vector_db_id is None: + provider_vector_db_id = vector_db_id + if provider_id is None: + if len(self.impls_by_provider_id) > 0: + provider_id = list(self.impls_by_provider_id.keys())[0] + if len(self.impls_by_provider_id) > 1: + logger.warning( + f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." + ) + else: + raise ValueError("No provider available. Please configure a vector_io provider.") + model = await self.get_object_by_identifier("model", embedding_model) + if model is None: + raise ValueError(f"Model {embedding_model} not found") + if model.model_type != ModelType.embedding: + raise ValueError(f"Model {embedding_model} is not an embedding model") + if "embedding_dimension" not in model.metadata: + raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + vector_db_data = { + "identifier": vector_db_id, + "type": ResourceType.vector_db.value, + "provider_id": provider_id, + "provider_resource_id": provider_vector_db_id, + "embedding_model": embedding_model, + "embedding_dimension": model.metadata["embedding_dimension"], + } + vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) + await self.register_object(vector_db) + return vector_db + + async def unregister_vector_db(self, vector_db_id: str) -> None: + existing_vector_db = await self.get_vector_db(vector_db_id) + if existing_vector_db is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + await self.unregister_object(existing_vector_db) diff --git a/pyproject.toml b/pyproject.toml index 8d8137233..0b3b30a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,6 +219,7 @@ exclude = [ "^llama_stack/distribution/client\\.py$", "^llama_stack/distribution/request_headers\\.py$", "^llama_stack/distribution/routers/", + "^llama_stack/distribution/routing_tables/", "^llama_stack/distribution/server/endpoints\\.py$", "^llama_stack/distribution/server/server\\.py$", "^llama_stack/distribution/stack\\.py$", diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 4e6585ad6..b5db6854a 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -17,15 +17,13 @@ from llama_stack.apis.models.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter from llama_stack.apis.vector_dbs.vector_dbs import VectorDB -from llama_stack.distribution.routers.routing_tables import ( - BenchmarksRoutingTable, - DatasetsRoutingTable, - ModelsRoutingTable, - ScoringFunctionsRoutingTable, - ShieldsRoutingTable, - ToolGroupsRoutingTable, - VectorDBsRoutingTable, -) +from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable +from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable +from llama_stack.distribution.routing_tables.models import ModelsRoutingTable +from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable +from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable +from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index b5e9c2698..e352ba54d 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -11,7 +11,7 @@ import pytest from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL -from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable +from llama_stack.distribution.routing_tables.models import ModelsRoutingTable class AsyncMock(MagicMock): @@ -37,7 +37,7 @@ async def test_setup(cached_disk_dist_registry): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( @@ -102,7 +102,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( @@ -132,7 +132,7 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model = ModelWithACL( @@ -154,7 +154,7 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_no_user_attributes(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( @@ -185,7 +185,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): """Test that newly created resources inherit access attributes from their creator.""" registry, routing_table = test_setup diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 821ddf1fe..bb4c15dbc 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -20,7 +20,7 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.routers.inference import InferenceRouter -from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable +from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec