From d76b3aa4d24bead98da4a0711c4f1c5fa5244512 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 22 Jul 2025 10:45:39 -0700 Subject: [PATCH] more fixes --- llama_stack/distribution/routers/safety.py | 3 +- .../distribution/routers/tool_runtime.py | 13 ++++---- llama_stack/distribution/routers/vector_io.py | 14 +++++---- .../distribution/routing_tables/toolgroups.py | 2 +- .../distribution/routing_tables/vector_dbs.py | 30 ++++++++++++------- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py index 9761d2db0..26ee8e722 100644 --- a/llama_stack/distribution/routers/safety.py +++ b/llama_stack/distribution/routers/safety.py @@ -50,7 +50,8 @@ class SafetyRouter(Safety): params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( + provider = await self.routing_table.get_provider_impl(shield_id) + return await provider.run_shield( shield_id=shield_id, messages=messages, params=params, diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py index 285843dbc..5a40bc0c5 100644 --- a/llama_stack/distribution/routers/tool_runtime.py +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime): query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") - return await self.routing_table.get_provider_impl("knowledge_search").query( - content, vector_db_ids, query_config - ) + provider = await self.routing_table.get_provider_impl("knowledge_search") + return await provider.query(content, vector_db_ids, query_config) async def insert( self, @@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + provider = await self.routing_table.get_provider_impl("insert_into_memory") + return await provider.insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime): async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") - return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + provider = await self.routing_table.get_provider_impl(tool_name) + return await provider.invoke_tool( tool_name=tool_name, kwargs=kwargs, ) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index a1dd66060..3d0996c49 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -104,7 +104,8 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + provider = await self.routing_table.get_provider_impl(vector_db_id) + return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds) async def query_chunks( self, @@ -113,7 +114,8 @@ class VectorIORouter(VectorIO): params: dict[str, Any] | None = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + provider = await self.routing_table.get_provider_impl(vector_db_id) + return await provider.query_chunks(vector_db_id, query, params) # OpenAI Vector Stores API endpoints async def openai_create_vector_store( @@ -146,7 +148,8 @@ class VectorIORouter(VectorIO): provider_vector_db_id=vector_db_id, vector_db_name=name, ) - return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store( + provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) + return await provider.openai_create_vector_store( name=name, file_ids=file_ids, expires_after=expires_after, @@ -172,9 +175,8 @@ class VectorIORouter(VectorIO): all_stores = [] for vector_db in vector_dbs: try: - vector_store = await self.routing_table.get_provider_impl( - vector_db.identifier - ).openai_retrieve_vector_store(vector_db.identifier) + provider = await self.routing_table.get_provider_impl(vector_db.identifier) + vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier) all_stores.append(vector_store) except Exception as e: logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}") diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index 5df38ab64..22c4e109a 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return ListToolsResponse(data=all_tools) async def _index_tools(self, toolgroup: ToolGroup): - provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) + provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) # TODO: kill this Tool vs ToolDef distinction diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index de1458f4c..58ecf24da 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -92,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id: str, ) -> VectorStoreObject: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -102,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -114,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id: str, ) -> VectorStoreDeleteResponse: await self.assert_action_allowed("delete", "vector_db", vector_store_id) - result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + result = await provider.openai_delete_vector_store(vector_store_id) await self.unregister_vector_db(vector_store_id) return result @@ -129,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -147,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -164,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -179,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileObject: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -190,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileContentsResponse: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -202,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): attributes: dict[str, Any], ) -> VectorStoreFileObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -214,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileDeleteResponse: await self.assert_action_allowed("delete", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, )