diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c49..7e2fc456c 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -171,7 +171,7 @@ class VectorIORouter(VectorIO): logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}") # Route to default provider for now - could aggregate from all providers in the future # call retrieve on each vector dbs to get list of vector stores - vector_dbs = await self.routing_table.get_all_with_type("vector_db") + vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db") all_stores = [] for vector_db in vector_dbs: try: diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 74bee8040..59e3025be 100644 --- a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -19,7 +19,7 @@ logger = get_logger(name=__name__, category="core") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) + return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark")) async def get_benchmark(self, benchmark_id: str) -> Benchmark: benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da4..6131a1573 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -232,15 +232,21 @@ class CommonRoutingTableImpl(RoutingTable): async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() - filtered_objs = [obj for obj in objs if obj.type == type] + return [obj for obj in objs if obj.type == type] + + async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]: + all_objs = await self.get_all_with_type(type=type) # Apply attribute-based access control filtering - if filtered_objs: - filtered_objs = [ - obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) + if all_objs: + all_objs = [ + obj + for obj in all_objs + if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) + and obj.provider_id in self.impls_by_provider_id ] - return filtered_objs + return all_objs async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model: @@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> ) # if not found, this means model_id is an unscoped provider_model_id, we need # to iterate (given a lack of an efficient index on the KVStore) - models = await routing_table.get_all_with_type("model") + models = await routing_table.get_all_with_type_filtered("model") matching_models = [m for m in models if m.provider_resource_id == model_id] if len(matching_models) == 0: raise ModelNotFoundError(model_id) diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df4..51fa2bec6 100644 --- a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -31,7 +31,7 @@ logger = get_logger(name=__name__, category="core") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value)) async def get_dataset(self, dataset_id: str) -> Dataset: dataset = await self.get_object_by_identifier("dataset", dataset_id) diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 34c431e00..8155ccc33 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -43,10 +43,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): await self.update_registered_models(provider_id, models) async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) + return ListModelsResponse(data=await self.get_all_with_type_filtered("model")) async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") + models = await self.get_all_with_type_filtered("model") openai_models = [ OpenAIModel( id=model.identifier, @@ -122,7 +122,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id: str, models: list[Model], ) -> None: - existing_models = await self.get_all_with_type("model") + existing_models = await self.get_all_with_type_filtered("model") # we may have an alias for the model registered by the user (or during initialization # from run.yaml) that we need to keep track of diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba941..9907522c4 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -24,7 +24,9 @@ logger = get_logger(name=__name__, category="core") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse( + data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value) + ) async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc..0e9d78a31 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="core") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value)) async def get_shield(self, identifier: str) -> Shield: shield = await self.get_object_by_identifier("shield", identifier) diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index 6910b3906..82f7cb819 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -49,7 +49,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): toolgroup_id = group_id toolgroups = [await self.get_tool_group(toolgroup_id)] else: - toolgroups = await self.get_all_with_type("tool_group") + toolgroups = await self.get_all_with_type_filtered("tool_group") all_tools = [] for toolgroup in toolgroups: @@ -83,7 +83,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) + return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group")) async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index e8dc46997..b2f252dcb 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -35,7 +35,7 @@ logger = get_logger(name=__name__, category="core") class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def list_vector_dbs(self) -> ListVectorDBsResponse: - return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) + return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db")) async def get_vector_db(self, vector_db_id: str) -> VectorDB: vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)