# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control import check_access from llama_stack.distribution.datatypes import ( AccessAttributes, RoutableObject, RoutableObjectWithProvider, RoutedProtocol, ) from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable logger = get_logger(name=__name__, category="core") def get_impl_api(p: Any) -> Api: return p.__provider_spec__.api # TODO: this should return the registered object for all APIs async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: api = get_impl_api(p) assert obj.provider_id != "remote", "Remote provider should not be registered" if api == Api.inference: return await p.register_model(obj) elif api == Api.safety: return await p.register_shield(obj) elif api == Api.vector_io: return await p.register_vector_db(obj) elif api == Api.datasetio: return await p.register_dataset(obj) elif api == Api.scoring: return await p.register_scoring_function(obj) elif api == Api.eval: return await p.register_benchmark(obj) elif api == Api.tool_runtime: return await p.register_toolgroup(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: api = get_impl_api(p) if api == Api.vector_io: return await p.unregister_vector_db(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: return await p.unregister_toolgroup(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") Registry = dict[str, list[RoutableObjectWithProvider]] class CommonRoutingTableImpl(RoutingTable): def __init__( self, impls_by_provider_id: dict[str, RoutedProtocol], dist_registry: DistributionRegistry, ) -> None: self.impls_by_provider_id = impls_by_provider_id self.dist_registry = dist_registry async def initialize(self) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id else: # Create a copy of the model data and explicitly set provider_id model_data = obj.model_dump() model_data["provider_id"] = provider_id obj = cls(**model_data) await self.dist_registry.register(obj) # Register all objects from providers for pid, p in self.impls_by_provider_id.items(): api = get_impl_api(p) if api == Api.inference: p.model_store = self elif api == Api.safety: p.shield_store = self elif api == Api.vector_io: p.vector_db_store = self elif api == Api.datasetio: p.dataset_store = self elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() await add_objects(scoring_functions, pid, ScoringFn) elif api == Api.eval: p.benchmark_store = self elif api == Api.tool_runtime: p.tool_store = self async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: from .benchmarks import BenchmarksRoutingTable from .datasets import DatasetsRoutingTable from .models import ModelsRoutingTable from .scoring_functions import ScoringFunctionsRoutingTable from .shields import ShieldsRoutingTable from .toolgroups import ToolGroupsRoutingTable from .vector_dbs import VectorDBsRoutingTable def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") elif isinstance(self, ShieldsRoutingTable): return ("Safety", "shield") elif isinstance(self, VectorDBsRoutingTable): return ("VectorIO", "vector_db") elif isinstance(self, DatasetsRoutingTable): return ("DatasetIO", "dataset") elif isinstance(self, ScoringFunctionsRoutingTable): return ("Scoring", "scoring_function") elif isinstance(self, BenchmarksRoutingTable): return ("Eval", "benchmark") elif isinstance(self, ToolGroupsRoutingTable): return ("ToolGroups", "tool_group") else: raise ValueError("Unknown routing table type") apiname, objtype = apiname_object() # Get objects from disk registry obj = self.dist_registry.get_cached(objtype, routing_key) if not obj: provider_ids = list(self.impls_by_provider_id.keys()) if len(provider_ids) > 1: provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" else: provider_ids_str = f"provider: `{provider_ids[0]}`" raise ValueError( f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." ) if not provider_id or provider_id == obj.provider_id: return self.impls_by_provider_id[obj.provider_id] raise ValueError(f"Provider not found for `{routing_key}`") async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: return None # Check if user has permission to access this object if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") return None return obj async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] if obj.provider_id not in self.impls_by_provider_id: raise ValueError(f"Provider `{obj.provider_id}` not found") p = self.impls_by_provider_id[obj.provider_id] # If object supports access control but no attributes set, use creator's attributes if not obj.access_attributes: creator_attributes = get_auth_attributes() if creator_attributes: obj.access_attributes = AccessAttributes(**creator_attributes) logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object if obj.type == ResourceType.model.value: await self.dist_registry.register(registered_obj) return registered_obj else: await self.dist_registry.register(obj) return obj async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() filtered_objs = [obj for obj in objs if obj.type == type] # Apply attribute-based access control filtering if filtered_objs: filtered_objs = [ obj for obj in filtered_objs if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) ] return filtered_objs