refactor(mypy): fix all type errors in core routing tables

Fix 163 mypy errors across 8 routing table files by addressing:
- union-attr errors from RoutedProtocol union type access
- arg-type mismatches between RoutableObjectWithProvider union and ProtectedResource protocol
- return-value incompatibilities between specific types and union types
- Action enum usage instead of string literals for access control
- Protocol signature updates for DistributionRegistry.get/get_cached methods
- Variable naming conflicts in nested loops (models.py)
- Override signature compatibility in benchmarks.py

Remove routing_tables/ exclusion from pyproject.toml.
This commit is contained in:
Ashwin Bharambe 2025-10-29 12:40:37 -07:00
parent b90c6a2c8b
commit 856b503226
10 changed files with 100 additions and 98 deletions

View file

@ -274,7 +274,6 @@ exclude = [
"^src/llama_stack/core/client\\.py$",
"^src/llama_stack/core/request_headers\\.py$",
"^src/llama_stack/core/routers/",
"^src/llama_stack/core/routing_tables/",
"^src/llama_stack/core/server/endpoints\\.py$",
"^src/llama_stack/core/server/server\\.py$",
"^src/llama_stack/core/stack\\.py$",

View file

@ -19,15 +19,15 @@ logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) # type: ignore[arg-type]
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
return benchmark # type: ignore[return-value]
async def register_benchmark(
async def register_benchmark( # type: ignore[override]
self,
benchmark_id: str,
dataset_id: str,
@ -59,4 +59,4 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def unregister_benchmark(self, benchmark_id: str) -> None:
existing_benchmark = await self.get_benchmark(benchmark_id)
await self.unregister_object(existing_benchmark)
await self.unregister_object(existing_benchmark) # type: ignore[arg-type]

View file

@ -27,7 +27,7 @@ logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
return p.__provider_spec__.api # type: ignore[no-any-return]
# TODO: this should return the registered object for all APIs
@ -37,19 +37,19 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
return await p.register_model(obj) # type: ignore[no-any-return]
elif api == Api.safety:
return await p.register_shield(obj)
return await p.register_shield(obj) # type: ignore[no-any-return]
elif api == Api.vector_io:
return await p.register_vector_store(obj)
return await p.register_vector_store(obj) # type: ignore[no-any-return]
elif api == Api.datasetio:
return await p.register_dataset(obj)
return await p.register_dataset(obj) # type: ignore[no-any-return]
elif api == Api.scoring:
return await p.register_scoring_function(obj)
return await p.register_scoring_function(obj) # type: ignore[no-any-return]
elif api == Api.eval:
return await p.register_benchmark(obj)
return await p.register_benchmark(obj) # type: ignore[no-any-return]
elif api == Api.tool_runtime:
return await p.register_toolgroup(obj)
return await p.register_toolgroup(obj) # type: ignore[no-any-return]
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -57,19 +57,19 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_store(obj.identifier)
await p.unregister_vector_store(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
await p.unregister_model(obj.identifier)
elif api == Api.safety:
return await p.unregister_shield(obj.identifier)
await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
await p.unregister_dataset(obj.identifier)
elif api == Api.eval:
return await p.unregister_benchmark(obj.identifier)
await p.unregister_benchmark(obj.identifier)
elif api == Api.scoring:
return await p.unregister_scoring_function(obj.identifier)
await p.unregister_scoring_function(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -104,25 +104,25 @@ class CommonRoutingTableImpl(RoutingTable):
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
p.model_store = self # type: ignore[union-attr]
elif api == Api.safety:
p.shield_store = self
p.shield_store = self # type: ignore[union-attr]
elif api == Api.vector_io:
p.vector_store_store = self
p.vector_store_store = self # type: ignore[union-attr]
elif api == Api.datasetio:
p.dataset_store = self
p.dataset_store = self # type: ignore[union-attr]
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
p.scoring_function_store = self # type: ignore[union-attr]
scoring_functions = await p.list_scoring_functions() # type: ignore[union-attr]
await add_objects(scoring_functions, pid, ScoringFnWithOwner)
elif api == Api.eval:
p.benchmark_store = self
p.benchmark_store = self # type: ignore[union-attr]
elif api == Api.tool_runtime:
p.tool_store = self
p.tool_store = self # type: ignore[union-attr]
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
await p.shutdown() # type: ignore[union-attr]
async def refresh(self) -> None:
pass
@ -180,7 +180,7 @@ class CommonRoutingTableImpl(RoutingTable):
return None
# Check if user has permission to access this object
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
if not is_action_allowed(self.policy, Action.READ, obj, get_authenticated_user()): # type: ignore[arg-type]
logger.debug(f"Access denied to {type} '{identifier}'")
return None
@ -188,8 +188,8 @@ class CommonRoutingTableImpl(RoutingTable):
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
user = get_authenticated_user()
if not is_action_allowed(self.policy, "delete", obj, user):
raise AccessDeniedError("delete", obj, user)
if not is_action_allowed(self.policy, Action.DELETE, obj, user): # type: ignore[arg-type]
raise AccessDeniedError(Action.DELETE, obj, user) # type: ignore[arg-type]
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
@ -205,8 +205,8 @@ class CommonRoutingTableImpl(RoutingTable):
# If object supports access control but no attributes set, use creator's attributes
creator = get_authenticated_user()
if not is_action_allowed(self.policy, "create", obj, creator):
raise AccessDeniedError("create", obj, creator)
if not is_action_allowed(self.policy, Action.CREATE, obj, creator): # type: ignore[arg-type]
raise AccessDeniedError(Action.CREATE, obj, creator) # type: ignore[arg-type]
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
@ -214,8 +214,8 @@ class CommonRoutingTableImpl(RoutingTable):
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
await self.dist_registry.register(registered_obj) # type: ignore[arg-type]
return registered_obj # type: ignore[return-value]
else:
await self.dist_registry.register(obj)
return obj
@ -231,8 +231,8 @@ class CommonRoutingTableImpl(RoutingTable):
if obj is None:
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
user = get_authenticated_user()
if not is_action_allowed(self.policy, action, obj, user):
raise AccessDeniedError(action, obj, user)
if not is_action_allowed(self.policy, action, obj, user): # type: ignore[arg-type]
raise AccessDeniedError(action, obj, user) # type: ignore[arg-type]
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
@ -241,7 +241,9 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
obj
for obj in filtered_objs
if is_action_allowed(self.policy, Action.READ, obj, get_authenticated_user()) # type: ignore[arg-type]
]
return filtered_objs
@ -251,4 +253,4 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
model = await routing_table.get_object_by_identifier("model", model_id)
if not model:
raise ModelNotFoundError(model_id)
return model
return model # type: ignore[return-value]

View file

@ -31,13 +31,13 @@ logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) # type: ignore[arg-type]
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise DatasetNotFoundError(dataset_id)
return dataset
return dataset # type: ignore[return-value]
async def register_dataset(
self,
@ -77,7 +77,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
dataset = DatasetWithOwner(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
provider_id=provider_id, # type: ignore[arg-type]
purpose=purpose,
source=source,
metadata=metadata,
@ -88,4 +88,4 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
await self.unregister_object(dataset)
await self.unregister_object(dataset) # type: ignore[arg-type]

View file

@ -25,13 +25,13 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
refresh = await provider.should_refresh_models() # type: ignore[union-attr]
refresh = refresh or provider_id not in self.listed_providers
if not refresh:
continue
try:
models = await provider.list_models()
models = await provider.list_models() # type: ignore[union-attr]
except Exception as e:
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
continue
@ -43,7 +43,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
return ListModelsResponse(data=await self.get_all_with_type("model")) # type: ignore[arg-type]
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
@ -61,8 +61,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def get_model(self, model_id: str) -> Model:
return await lookup_model(self, model_id)
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
model = await lookup_model(self, routing_key)
if model.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
@ -114,13 +114,13 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
return registered_model # type: ignore[return-value]
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ModelNotFoundError(model_id)
await self.unregister_object(existing_model)
await self.unregister_object(existing_model) # type: ignore[arg-type]
async def update_registered_models(
self,
@ -142,22 +142,22 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.provider_resource_id in model_ids:
for provider_model in models:
if provider_model.provider_resource_id in model_ids:
# avoid overwriting a non-provider-registered model entry
continue
if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}"
if provider_model.identifier == provider_model.provider_resource_id:
provider_model.identifier = f"{provider_id}/{provider_model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
logger.debug(f"registering model {provider_model.identifier} ({provider_model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,
provider_resource_id=model.provider_resource_id,
identifier=provider_model.identifier,
provider_resource_id=provider_model.provider_resource_id,
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
metadata=provider_model.metadata,
model_type=provider_model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)

View file

@ -24,13 +24,13 @@ logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) # type: ignore[arg-type]
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
return scoring_fn # type: ignore[return-value]
async def register_scoring_function(
self,
@ -63,4 +63,4 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
await self.unregister_object(existing_scoring_fn)
await self.unregister_object(existing_scoring_fn) # type: ignore[arg-type]

View file

@ -20,13 +20,13 @@ logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) # type: ignore[arg-type]
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
return shield # type: ignore[return-value]
async def register_shield(
self,
@ -58,4 +58,4 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def unregister_shield(self, identifier: str) -> None:
existing_shield = await self.get_shield(identifier)
await self.unregister_object(existing_shield)
await self.unregister_object(existing_shield) # type: ignore[arg-type]

View file

@ -49,7 +49,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
toolgroups = await self.get_all_with_type("tool_group") # type: ignore[assignment]
all_tools = []
for toolgroup in toolgroups:
@ -83,13 +83,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) # type: ignore[arg-type]
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ToolGroupNotFoundError(toolgroup_id)
return tool_group
return tool_group # type: ignore[return-value]
async def get_tool(self, tool_name: str) -> ToolDef:
if tool_name in self.tool_to_toolgroup:
@ -123,7 +123,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self._index_tools(toolgroup)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
await self.unregister_object(await self.get_tool_group(toolgroup_id))
await self.unregister_object(await self.get_tool_group(toolgroup_id)) # type: ignore[arg-type]
async def shutdown(self) -> None:
pass

View file

@ -9,6 +9,7 @@ from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.core.access_control.datatypes import Action
# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
@ -67,11 +68,11 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
vector_store = VectorStoreWithOwner(
identifier=vector_store_id,
type=ResourceType.vector_store.value,
type=ResourceType.vector_store, # type: ignore[arg-type]
provider_id=provider_id,
provider_resource_id=provider_vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
embedding_dimension=embedding_dimension or 384, # type: ignore[arg-type]
vector_store_name=vector_store_name,
)
await self.register_object(vector_store)
@ -81,9 +82,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id) # type: ignore[no-any-return]
async def openai_update_vector_store(
self,
@ -92,9 +93,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
return await provider.openai_update_vector_store( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
@ -105,11 +106,11 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.DELETE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_store(vector_store_id)
return result
return result # type: ignore[no-any-return]
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Remove the vector store from the routing table registry."""
@ -131,9 +132,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
return await provider.openai_search_vector_store( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
query=query,
filters=filters,
@ -150,9 +151,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
return await provider.openai_attach_file_to_vector_store( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -168,9 +169,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
return await provider.openai_list_files_in_vector_store( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
limit=limit,
order=order,
@ -184,9 +185,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
return await provider.openai_retrieve_vector_store_file( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -196,9 +197,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
return await provider.openai_retrieve_vector_store_file_contents( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -209,9 +210,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
return await provider.openai_update_vector_store_file( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -222,9 +223,9 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.DELETE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
return await provider.openai_delete_vector_store_file( # type: ignore[no-any-return]
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -236,7 +237,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
@ -250,7 +251,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
@ -267,7 +268,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.READ, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
@ -284,7 +285,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_store", vector_store_id)
await self.assert_action_allowed(Action.UPDATE, "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,

View file

@ -23,9 +23,9 @@ class DistributionRegistry(Protocol):
async def initialize(self) -> None: ...
async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: ...
def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...