From 2692d73d7d6c482d33750a17a610dd214ea5b8a2 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 19 Mar 2025 15:09:41 -0700 Subject: [PATCH] update routing table registration types --- llama_stack/distribution/datatypes.py | 16 ++++++------- .../distribution/routers/routing_tables.py | 24 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index aebf8b0e7..48f1925dd 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -142,14 +142,14 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL): RoutableObject = Union[ - ModelWithACL, - ShieldWithACL, - VectorDBWithACL, - DatasetWithACL, - ScoringFnWithACL, - BenchmarkWithACL, - ToolWithACL, - ToolGroupWithACL, + Model, + Shield, + VectorDB, + Dataset, + ScoringFn, + Benchmark, + Tool, + ToolGroup, ] diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index f756c8621..8c99ef6be 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -42,9 +42,17 @@ from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorD from llama_stack.distribution.access_control import check_access from llama_stack.distribution.datatypes import ( AccessAttributes, + BenchmarkWithACL, + DatasetWithACL, + ModelWithACL, RoutableObject, RoutableObjectWithProvider, RoutedProtocol, + ScoringFnWithACL, + ShieldWithACL, + ToolGroupWithACL, + ToolWithACL, + VectorDBWithACL, ) from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.store import DistributionRegistry @@ -270,7 +278,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = Model( + model = ModelWithACL( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, @@ -316,7 +324,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) if params is None: params = {} - shield = Shield( + shield = ShieldWithACL( identifier=shield_id, provider_resource_id=provider_shield_id, provider_id=provider_id, @@ -370,7 +378,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], } - vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data) + vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) await self.register_object(vector_db) return vector_db @@ -418,7 +426,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): if metadata is None: metadata = {} - dataset = Dataset( + dataset = DatasetWithACL( identifier=dataset_id, provider_resource_id=provider_dataset_id, provider_id=provider_id, @@ -465,7 +473,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." ) - scoring_fn = ScoringFn( + scoring_fn = ScoringFnWithACL( identifier=scoring_fn_id, description=description, return_type=return_type, @@ -507,7 +515,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): ) if provider_benchmark_id is None: provider_benchmark_id = benchmark_id - benchmark = Benchmark( + benchmark = BenchmarkWithACL( identifier=benchmark_id, dataset_id=dataset_id, scoring_functions=scoring_functions, @@ -550,7 +558,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): for tool_def in tool_defs: tools.append( - Tool( + ToolWithACL( identifier=tool_def.name, toolgroup_id=toolgroup_id, description=tool_def.description or "", @@ -575,7 +583,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): await self.register_object(tool) await self.dist_registry.register( - ToolGroup( + ToolGroupWithACL( identifier=toolgroup_id, provider_id=provider_id, provider_resource_id=toolgroup_id,