From e0ce15b8b5c663e8f235e68ead03c6d528e8cac2 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Tue, 25 Mar 2025 10:32:08 +0000 Subject: [PATCH] :rewind: ensuring formatting changes are reverted in routing_tables.py while maintaining functional changes in the `ShieldsRoutingTable` --- .../distribution/routers/routing_tables.py | 92 +++++-------------- 1 file changed, 25 insertions(+), 67 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 048a67c3b..20cab6df9 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -118,9 +118,7 @@ class CommonRoutingTableImpl(RoutingTable): self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects( - objs: List[RoutableObjectWithProvider], provider_id: str, cls - ) -> None: + async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -155,9 +153,7 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl( - self, routing_key: str, provider_id: Optional[str] = None - ) -> Any: + def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -195,32 +191,24 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier( - self, type: str, identifier: str - ) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: return None # Check if user has permission to access this object - if not check_access(obj, get_auth_attributes()): - logger.debug( - f"Access denied to {type} '{identifier}' based on attribute mismatch" - ) + if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): + logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") return None return obj async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider( - obj, self.impls_by_provider_id[obj.provider_id] - ) + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) - async def register_object( - self, obj: RoutableObjectWithProvider - ) -> RoutableObjectWithProvider: + async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -235,9 +223,7 @@ class CommonRoutingTableImpl(RoutingTable): creator_attributes = get_auth_attributes() if creator_attributes: obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info( - f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity" - ) + logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object @@ -256,7 +242,9 @@ class CommonRoutingTableImpl(RoutingTable): # Apply attribute-based access control filtering if filtered_objs: filtered_objs = [ - obj for obj in filtered_objs if check_access(obj, get_auth_attributes()) + obj + for obj in filtered_objs + if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) ] return filtered_objs @@ -295,9 +283,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if model_type is None: model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError( - "Embedding model must have an embedding dimension in its metadata" - ) + raise ValueError("Embedding model must have an embedding dimension in its metadata") model = ModelWithACL( identifier=model_id, provider_resource_id=provider_model_id, @@ -331,39 +317,25 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): # Fetch shields after initialization - with robust error handling try: # Check if the provider implements list_shields - if hasattr(provider, "list_shields") and callable( - getattr(provider, "list_shields") - ): + if hasattr(provider, "list_shields") and callable(getattr(provider, "list_shields")): shields_response = await provider.list_shields() - if ( - shields_response - and hasattr(shields_response, "data") - and shields_response.data - ): + if shields_response and hasattr(shields_response, "data") and shields_response.data: for shield in shields_response.data: # Ensure type is set if not hasattr(shield, "type") or not shield.type: shield.type = ResourceType.shield.value await self.dist_registry.register(shield) - logger.info( - f"Registered {len(shields_response.data)} shields from provider {provider_id}" - ) + logger.info(f"Registered {len(shields_response.data)} shields from provider {provider_id}") else: logger.info(f"No shields found for provider {provider_id}") else: - logger.info( - f"Provider {provider_id} does not support listing shields" - ) + logger.info(f"Provider {provider_id} does not support listing shields") except Exception as e: # Log the error but continue initialization - logger.warning( - f"Error listing shields from provider {provider_id}: {str(e)}" - ) + logger.warning(f"Error listing shields from provider {provider_id}: {str(e)}") async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse( - data=await self.get_all_with_type(ResourceType.shield.value) - ) + return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) async def get_shield(self, identifier: str) -> Shield: shield = await self.get_object_by_identifier("shield", identifier) @@ -428,18 +400,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: - raise ValueError( - "No provider available. Please configure a vector_io provider." - ) + raise ValueError("No provider available. Please configure a vector_io provider.") model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: - raise ValueError( - f"Model {embedding_model} does not have an embedding dimension" - ) + raise ValueError(f"Model {embedding_model} does not have an embedding dimension") vector_db_data = { "identifier": vector_db_id, "type": ResourceType.vector_db.value, @@ -461,9 +429,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse( - data=await self.get_all_with_type(ResourceType.dataset.value) - ) + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) async def get_dataset(self, dataset_id: str) -> Dataset: dataset = await self.get_object_by_identifier("dataset", dataset_id) @@ -525,14 +491,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse( - data=await self.get_all_with_type(ResourceType.scoring_function.value) - ) + return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier( - "scoring_function", scoring_fn_id - ) + scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) if scoring_fn is None: raise ValueError(f"Scoring function '{scoring_fn_id}' not found") return scoring_fn @@ -635,12 +597,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( - toolgroup_id, mcp_endpoint - ) - tool_host = ( - ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - ) + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) + tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution for tool_def in tool_defs: tools.append(