diff --git a/pyproject.toml b/pyproject.toml index 999c3d9a3..00efe4bdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -274,7 +274,6 @@ exclude = [ "^src/llama_stack/core/client\\.py$", "^src/llama_stack/core/request_headers\\.py$", "^src/llama_stack/core/routers/", - "^src/llama_stack/core/routing_tables/", "^src/llama_stack/core/server/endpoints\\.py$", "^src/llama_stack/core/server/server\\.py$", "^src/llama_stack/core/stack\\.py$", diff --git a/src/llama_stack/core/routing_tables/benchmarks.py b/src/llama_stack/core/routing_tables/benchmarks.py index 8c87d395d..e8da4eca7 100644 --- a/src/llama_stack/core/routing_tables/benchmarks.py +++ b/src/llama_stack/core/routing_tables/benchmarks.py @@ -19,15 +19,15 @@ logger = get_logger(name=__name__, category="core::routing_tables") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) + return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) # type: ignore[arg-type] async def get_benchmark(self, benchmark_id: str) -> Benchmark: benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) if benchmark is None: raise ValueError(f"Benchmark '{benchmark_id}' not found") - return benchmark + return benchmark # type: ignore[return-value] - async def register_benchmark( + async def register_benchmark( # type: ignore[override] self, benchmark_id: str, dataset_id: str, @@ -59,4 +59,4 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def unregister_benchmark(self, benchmark_id: str) -> None: existing_benchmark = await self.get_benchmark(benchmark_id) - await self.unregister_object(existing_benchmark) + await self.unregister_object(existing_benchmark) # type: ignore[arg-type] diff --git a/src/llama_stack/core/routing_tables/common.py b/src/llama_stack/core/routing_tables/common.py index d6faf93c5..0fcd71e51 100644 --- a/src/llama_stack/core/routing_tables/common.py +++ b/src/llama_stack/core/routing_tables/common.py @@ -27,7 +27,7 @@ logger = get_logger(name=__name__, category="core::routing_tables") def get_impl_api(p: Any) -> Api: - return p.__provider_spec__.api + return p.__provider_spec__.api # type: ignore[no-any-return] # TODO: this should return the registered object for all APIs @@ -37,19 +37,19 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable assert obj.provider_id != "remote", "Remote provider should not be registered" if api == Api.inference: - return await p.register_model(obj) + return await p.register_model(obj) # type: ignore[no-any-return] elif api == Api.safety: - return await p.register_shield(obj) + return await p.register_shield(obj) # type: ignore[no-any-return] elif api == Api.vector_io: - return await p.register_vector_store(obj) + return await p.register_vector_store(obj) # type: ignore[no-any-return] elif api == Api.datasetio: - return await p.register_dataset(obj) + return await p.register_dataset(obj) # type: ignore[no-any-return] elif api == Api.scoring: - return await p.register_scoring_function(obj) + return await p.register_scoring_function(obj) # type: ignore[no-any-return] elif api == Api.eval: - return await p.register_benchmark(obj) + return await p.register_benchmark(obj) # type: ignore[no-any-return] elif api == Api.tool_runtime: - return await p.register_toolgroup(obj) + return await p.register_toolgroup(obj) # type: ignore[no-any-return] else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -57,19 +57,19 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: api = get_impl_api(p) if api == Api.vector_io: - return await p.unregister_vector_store(obj.identifier) + await p.unregister_vector_store(obj.identifier) elif api == Api.inference: - return await p.unregister_model(obj.identifier) + await p.unregister_model(obj.identifier) elif api == Api.safety: - return await p.unregister_shield(obj.identifier) + await p.unregister_shield(obj.identifier) elif api == Api.datasetio: - return await p.unregister_dataset(obj.identifier) + await p.unregister_dataset(obj.identifier) elif api == Api.eval: - return await p.unregister_benchmark(obj.identifier) + await p.unregister_benchmark(obj.identifier) elif api == Api.scoring: - return await p.unregister_scoring_function(obj.identifier) + await p.unregister_scoring_function(obj.identifier) elif api == Api.tool_runtime: - return await p.unregister_toolgroup(obj.identifier) + await p.unregister_toolgroup(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -104,25 +104,25 @@ class CommonRoutingTableImpl(RoutingTable): for pid, p in self.impls_by_provider_id.items(): api = get_impl_api(p) if api == Api.inference: - p.model_store = self + p.model_store = self # type: ignore[union-attr] elif api == Api.safety: - p.shield_store = self + p.shield_store = self # type: ignore[union-attr] elif api == Api.vector_io: - p.vector_store_store = self + p.vector_store_store = self # type: ignore[union-attr] elif api == Api.datasetio: - p.dataset_store = self + p.dataset_store = self # type: ignore[union-attr] elif api == Api.scoring: - p.scoring_function_store = self - scoring_functions = await p.list_scoring_functions() + p.scoring_function_store = self # type: ignore[union-attr] + scoring_functions = await p.list_scoring_functions() # type: ignore[union-attr] await add_objects(scoring_functions, pid, ScoringFnWithOwner) elif api == Api.eval: - p.benchmark_store = self + p.benchmark_store = self # type: ignore[union-attr] elif api == Api.tool_runtime: - p.tool_store = self + p.tool_store = self # type: ignore[union-attr] async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): - await p.shutdown() + await p.shutdown() # type: ignore[union-attr] async def refresh(self) -> None: pass @@ -180,7 +180,7 @@ class CommonRoutingTableImpl(RoutingTable): return None # Check if user has permission to access this object - if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()): + if not is_action_allowed(self.policy, Action.READ, obj, get_authenticated_user()): # type: ignore[arg-type] logger.debug(f"Access denied to {type} '{identifier}'") return None @@ -188,8 +188,8 @@ class CommonRoutingTableImpl(RoutingTable): async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: user = get_authenticated_user() - if not is_action_allowed(self.policy, "delete", obj, user): - raise AccessDeniedError("delete", obj, user) + if not is_action_allowed(self.policy, Action.DELETE, obj, user): # type: ignore[arg-type] + raise AccessDeniedError(Action.DELETE, obj, user) # type: ignore[arg-type] await self.dist_registry.delete(obj.type, obj.identifier) await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) @@ -205,8 +205,8 @@ class CommonRoutingTableImpl(RoutingTable): # If object supports access control but no attributes set, use creator's attributes creator = get_authenticated_user() - if not is_action_allowed(self.policy, "create", obj, creator): - raise AccessDeniedError("create", obj, creator) + if not is_action_allowed(self.policy, Action.CREATE, obj, creator): # type: ignore[arg-type] + raise AccessDeniedError(Action.CREATE, obj, creator) # type: ignore[arg-type] if creator: obj.owner = creator logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") @@ -214,8 +214,8 @@ class CommonRoutingTableImpl(RoutingTable): registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object if obj.type == ResourceType.model.value: - await self.dist_registry.register(registered_obj) - return registered_obj + await self.dist_registry.register(registered_obj) # type: ignore[arg-type] + return registered_obj # type: ignore[return-value] else: await self.dist_registry.register(obj) return obj @@ -231,8 +231,8 @@ class CommonRoutingTableImpl(RoutingTable): if obj is None: raise ValueError(f"{type.capitalize()} '{identifier}' not found") user = get_authenticated_user() - if not is_action_allowed(self.policy, action, obj, user): - raise AccessDeniedError(action, obj, user) + if not is_action_allowed(self.policy, action, obj, user): # type: ignore[arg-type] + raise AccessDeniedError(action, obj, user) # type: ignore[arg-type] async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() @@ -241,7 +241,9 @@ class CommonRoutingTableImpl(RoutingTable): # Apply attribute-based access control filtering if filtered_objs: filtered_objs = [ - obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) + obj + for obj in filtered_objs + if is_action_allowed(self.policy, Action.READ, obj, get_authenticated_user()) # type: ignore[arg-type] ] return filtered_objs @@ -251,4 +253,4 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> model = await routing_table.get_object_by_identifier("model", model_id) if not model: raise ModelNotFoundError(model_id) - return model + return model # type: ignore[return-value] diff --git a/src/llama_stack/core/routing_tables/datasets.py b/src/llama_stack/core/routing_tables/datasets.py index b129c9ec5..5b60b774e 100644 --- a/src/llama_stack/core/routing_tables/datasets.py +++ b/src/llama_stack/core/routing_tables/datasets.py @@ -31,13 +31,13 @@ logger = get_logger(name=__name__, category="core::routing_tables") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) # type: ignore[arg-type] async def get_dataset(self, dataset_id: str) -> Dataset: dataset = await self.get_object_by_identifier("dataset", dataset_id) if dataset is None: raise DatasetNotFoundError(dataset_id) - return dataset + return dataset # type: ignore[return-value] async def register_dataset( self, @@ -77,7 +77,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): dataset = DatasetWithOwner( identifier=dataset_id, provider_resource_id=provider_dataset_id, - provider_id=provider_id, + provider_id=provider_id, # type: ignore[arg-type] purpose=purpose, source=source, metadata=metadata, @@ -88,4 +88,4 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def unregister_dataset(self, dataset_id: str) -> None: dataset = await self.get_dataset(dataset_id) - await self.unregister_object(dataset) + await self.unregister_object(dataset) # type: ignore[arg-type] diff --git a/src/llama_stack/core/routing_tables/models.py b/src/llama_stack/core/routing_tables/models.py index 7e43d7273..cb26e4160 100644 --- a/src/llama_stack/core/routing_tables/models.py +++ b/src/llama_stack/core/routing_tables/models.py @@ -25,13 +25,13 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def refresh(self) -> None: for provider_id, provider in self.impls_by_provider_id.items(): - refresh = await provider.should_refresh_models() + refresh = await provider.should_refresh_models() # type: ignore[union-attr] refresh = refresh or provider_id not in self.listed_providers if not refresh: continue try: - models = await provider.list_models() + models = await provider.list_models() # type: ignore[union-attr] except Exception as e: logger.warning(f"Model refresh failed for provider {provider_id}: {e}") continue @@ -43,7 +43,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): await self.update_registered_models(provider_id, models) async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) + return ListModelsResponse(data=await self.get_all_with_type("model")) # type: ignore[arg-type] async def openai_list_models(self) -> OpenAIListModelsResponse: models = await self.get_all_with_type("model") @@ -61,8 +61,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def get_model(self, model_id: str) -> Model: return await lookup_model(self, model_id) - async def get_provider_impl(self, model_id: str) -> Any: - model = await lookup_model(self, model_id) + async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + model = await lookup_model(self, routing_key) if model.provider_id not in self.impls_by_provider_id: raise ValueError(f"Provider {model.provider_id} not found in the routing table") return self.impls_by_provider_id[model.provider_id] @@ -114,13 +114,13 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): source=RegistryEntrySource.via_register_api, ) registered_model = await self.register_object(model) - return registered_model + return registered_model # type: ignore[return-value] async def unregister_model(self, model_id: str) -> None: existing_model = await self.get_model(model_id) if existing_model is None: raise ModelNotFoundError(model_id) - await self.unregister_object(existing_model) + await self.unregister_object(existing_model) # type: ignore[arg-type] async def update_registered_models( self, @@ -142,22 +142,22 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): logger.debug(f"unregistering model {model.identifier}") await self.unregister_object(model) - for model in models: - if model.provider_resource_id in model_ids: + for provider_model in models: + if provider_model.provider_resource_id in model_ids: # avoid overwriting a non-provider-registered model entry continue - if model.identifier == model.provider_resource_id: - model.identifier = f"{provider_id}/{model.provider_resource_id}" + if provider_model.identifier == provider_model.provider_resource_id: + provider_model.identifier = f"{provider_id}/{provider_model.provider_resource_id}" - logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") + logger.debug(f"registering model {provider_model.identifier} ({provider_model.provider_resource_id})") await self.register_object( ModelWithOwner( - identifier=model.identifier, - provider_resource_id=model.provider_resource_id, + identifier=provider_model.identifier, + provider_resource_id=provider_model.provider_resource_id, provider_id=provider_id, - metadata=model.metadata, - model_type=model.model_type, + metadata=provider_model.metadata, + model_type=provider_model.model_type, source=RegistryEntrySource.listed_from_provider, ) ) diff --git a/src/llama_stack/core/routing_tables/scoring_functions.py b/src/llama_stack/core/routing_tables/scoring_functions.py index 520f07014..2fd37c0a1 100644 --- a/src/llama_stack/core/routing_tables/scoring_functions.py +++ b/src/llama_stack/core/routing_tables/scoring_functions.py @@ -24,13 +24,13 @@ logger = get_logger(name=__name__, category="core::routing_tables") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) # type: ignore[arg-type] async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) if scoring_fn is None: raise ValueError(f"Scoring function '{scoring_fn_id}' not found") - return scoring_fn + return scoring_fn # type: ignore[return-value] async def register_scoring_function( self, @@ -63,4 +63,4 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def unregister_scoring_function(self, scoring_fn_id: str) -> None: existing_scoring_fn = await self.get_scoring_function(scoring_fn_id) - await self.unregister_object(existing_scoring_fn) + await self.unregister_object(existing_scoring_fn) # type: ignore[arg-type] diff --git a/src/llama_stack/core/routing_tables/shields.py b/src/llama_stack/core/routing_tables/shields.py index b1918d20a..59f431573 100644 --- a/src/llama_stack/core/routing_tables/shields.py +++ b/src/llama_stack/core/routing_tables/shields.py @@ -20,13 +20,13 @@ logger = get_logger(name=__name__, category="core::routing_tables") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) # type: ignore[arg-type] async def get_shield(self, identifier: str) -> Shield: shield = await self.get_object_by_identifier("shield", identifier) if shield is None: raise ValueError(f"Shield '{identifier}' not found") - return shield + return shield # type: ignore[return-value] async def register_shield( self, @@ -58,4 +58,4 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def unregister_shield(self, identifier: str) -> None: existing_shield = await self.get_shield(identifier) - await self.unregister_object(existing_shield) + await self.unregister_object(existing_shield) # type: ignore[arg-type] diff --git a/src/llama_stack/core/routing_tables/toolgroups.py b/src/llama_stack/core/routing_tables/toolgroups.py index 2d47bbb17..886adb092 100644 --- a/src/llama_stack/core/routing_tables/toolgroups.py +++ b/src/llama_stack/core/routing_tables/toolgroups.py @@ -49,7 +49,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): toolgroup_id = group_id toolgroups = [await self.get_tool_group(toolgroup_id)] else: - toolgroups = await self.get_all_with_type("tool_group") + toolgroups = await self.get_all_with_type("tool_group") # type: ignore[assignment] all_tools = [] for toolgroup in toolgroups: @@ -83,13 +83,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): self.tool_to_toolgroup[tool.name] = toolgroup.identifier async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) + return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) # type: ignore[arg-type] async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) if tool_group is None: raise ToolGroupNotFoundError(toolgroup_id) - return tool_group + return tool_group # type: ignore[return-value] async def get_tool(self, tool_name: str) -> ToolDef: if tool_name in self.tool_to_toolgroup: @@ -123,7 +123,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): await self._index_tools(toolgroup) async def unregister_toolgroup(self, toolgroup_id: str) -> None: - await self.unregister_object(await self.get_tool_group(toolgroup_id)) + await self.unregister_object(await self.get_tool_group(toolgroup_id)) # type: ignore[arg-type] async def shutdown(self) -> None: pass diff --git a/src/llama_stack/core/routing_tables/vector_stores.py b/src/llama_stack/core/routing_tables/vector_stores.py index c6c80a01e..af3c62855 100644 --- a/src/llama_stack/core/routing_tables/vector_stores.py +++ b/src/llama_stack/core/routing_tables/vector_stores.py @@ -9,6 +9,7 @@ from typing import Any from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType +from llama_stack.core.access_control.datatypes import Action # Removed VectorStores import to avoid exposing public API from llama_stack.apis.vector_io.vector_io import ( @@ -67,11 +68,11 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): vector_store = VectorStoreWithOwner( identifier=vector_store_id, - type=ResourceType.vector_store.value, + type=ResourceType.vector_store, # type: ignore[arg-type] provider_id=provider_id, provider_resource_id=provider_vector_store_id, embedding_model=embedding_model, - embedding_dimension=embedding_dimension, + embedding_dimension=embedding_dimension or 384, # type: ignore[arg-type] vector_store_name=vector_store_name, ) await self.register_object(vector_store) @@ -81,9 +82,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): self, vector_store_id: str, ) -> VectorStoreObject: - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store(vector_store_id) + return await provider.openai_retrieve_vector_store(vector_store_id) # type: ignore[no-any-return] async def openai_update_vector_store( self, @@ -92,9 +93,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): expires_after: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: - await self.assert_action_allowed("update", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store( + return await provider.openai_update_vector_store( # type: ignore[no-any-return] vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -105,11 +106,11 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): self, vector_store_id: str, ) -> VectorStoreDeleteResponse: - await self.assert_action_allowed("delete", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.DELETE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) result = await provider.openai_delete_vector_store(vector_store_id) await self.unregister_vector_store(vector_store_id) - return result + return result # type: ignore[no-any-return] async def unregister_vector_store(self, vector_store_id: str) -> None: """Remove the vector store from the routing table registry.""" @@ -131,9 +132,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): rewrite_query: bool | None = False, search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_search_vector_store( + return await provider.openai_search_vector_store( # type: ignore[no-any-return] vector_store_id=vector_store_id, query=query, filters=filters, @@ -150,9 +151,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): attributes: dict[str, Any] | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: - await self.assert_action_allowed("update", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_attach_file_to_vector_store( + return await provider.openai_attach_file_to_vector_store( # type: ignore[no-any-return] vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -168,9 +169,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): before: str | None = None, filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_list_files_in_vector_store( + return await provider.openai_list_files_in_vector_store( # type: ignore[no-any-return] vector_store_id=vector_store_id, limit=limit, order=order, @@ -184,9 +185,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): vector_store_id: str, file_id: str, ) -> VectorStoreFileObject: - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file( + return await provider.openai_retrieve_vector_store_file( # type: ignore[no-any-return] vector_store_id=vector_store_id, file_id=file_id, ) @@ -196,9 +197,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): vector_store_id: str, file_id: str, ) -> VectorStoreFileContentsResponse: - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + return await provider.openai_retrieve_vector_store_file_contents( # type: ignore[no-any-return] vector_store_id=vector_store_id, file_id=file_id, ) @@ -209,9 +210,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): file_id: str, attributes: dict[str, Any], ) -> VectorStoreFileObject: - await self.assert_action_allowed("update", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store_file( + return await provider.openai_update_vector_store_file( # type: ignore[no-any-return] vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -222,9 +223,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): vector_store_id: str, file_id: str, ) -> VectorStoreFileDeleteResponse: - await self.assert_action_allowed("delete", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.DELETE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) - return await provider.openai_delete_vector_store_file( + return await provider.openai_delete_vector_store_file( # type: ignore[no-any-return] vector_store_id=vector_store_id, file_id=file_id, ) @@ -236,7 +237,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): attributes: dict[str, Any] | None = None, chunking_strategy: Any | None = None, ): - await self.assert_action_allowed("update", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) return await provider.openai_create_vector_store_file_batch( vector_store_id=vector_store_id, @@ -250,7 +251,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): batch_id: str, vector_store_id: str, ): - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) return await provider.openai_retrieve_vector_store_file_batch( batch_id=batch_id, @@ -267,7 +268,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): limit: int | None = 20, order: str | None = "desc", ): - await self.assert_action_allowed("read", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) return await provider.openai_list_files_in_vector_store_file_batch( batch_id=batch_id, @@ -284,7 +285,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): batch_id: str, vector_store_id: str, ): - await self.assert_action_allowed("update", "vector_store", vector_store_id) + await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id) provider = await self.get_provider_impl(vector_store_id) return await provider.openai_cancel_vector_store_file_batch( batch_id=batch_id, diff --git a/src/llama_stack/core/store/registry.py b/src/llama_stack/core/store/registry.py index 6ff9e575b..47723e26e 100644 --- a/src/llama_stack/core/store/registry.py +++ b/src/llama_stack/core/store/registry.py @@ -23,9 +23,9 @@ class DistributionRegistry(Protocol): async def initialize(self) -> None: ... - async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ... + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: ... - def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ... + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: ... async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...