more fixes

This commit is contained in:
Ashwin Bharambe 2025-07-22 10:45:39 -07:00
parent a66074a10e
commit d76b3aa4d2
5 changed files with 37 additions and 25 deletions

View file

@ -50,7 +50,8 @@ class SafetyRouter(Safety):
params: dict[str, Any] = None, params: dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}") logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield( provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id, shield_id=shield_id,
messages=messages, messages=messages,
params=params, params=params,

View file

@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
query_config: RAGQueryConfig | None = None, query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query( provider = await self.routing_table.get_provider_impl("knowledge_search")
content, vector_db_ids, query_config return await provider.query(content, vector_db_ids, query_config)
)
async def insert( async def insert(
self, self,
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug( logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
) )
return await self.routing_table.get_provider_impl("insert_into_memory").insert( provider = await self.routing_table.get_provider_impl("insert_into_memory")
documents, vector_db_id, chunk_size_in_tokens return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
)
def __init__( def __init__(
self, self,
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool( provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name, tool_name=tool_name,
kwargs=kwargs, kwargs=kwargs,
) )

View file

@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
logger.debug( logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
) )
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks( async def query_chunks(
self, self,
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints # OpenAI Vector Stores API endpoints
async def openai_create_vector_store( async def openai_create_vector_store(
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
provider_vector_db_id=vector_db_id, provider_vector_db_id=vector_db_id,
vector_db_name=name, vector_db_name=name,
) )
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store( provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name, name=name,
file_ids=file_ids, file_ids=file_ids,
expires_after=expires_after, expires_after=expires_after,
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
all_stores = [] all_stores = []
for vector_db in vector_dbs: for vector_db in vector_dbs:
try: try:
vector_store = await self.routing_table.get_provider_impl( provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_db.identifier vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
).openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store) all_stores.append(vector_store)
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}") logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")

View file

@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolsResponse(data=all_tools) return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup): async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction # TODO: kill this Tool vs ToolDef distinction

View file

@ -92,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreObject: ) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store( async def openai_update_vector_store(
self, self,
@ -102,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
name=name, name=name,
expires_after=expires_after, expires_after=expires_after,
@ -114,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreDeleteResponse: ) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id) await self.assert_action_allowed("delete", "vector_db", vector_store_id)
result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id) await self.unregister_vector_db(vector_store_id)
return result return result
@ -129,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
search_mode: str | None = "vector", search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage: ) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_search_vector_store( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
query=query, query=query,
filters=filters, filters=filters,
@ -147,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
chunking_strategy: VectorStoreChunkingStrategy | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
attributes=attributes, attributes=attributes,
@ -164,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
filter: VectorStoreFileStatus | None = None, filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]: ) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
limit=limit, limit=limit,
order=order, order=order,
@ -179,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str, file_id: str,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )
@ -190,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )
@ -202,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
attributes: dict[str, Any], attributes: dict[str, Any],
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
attributes=attributes, attributes=attributes,
@ -214,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
file_id: str, file_id: str,
) -> VectorStoreFileDeleteResponse: ) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id) await self.assert_action_allowed("delete", "vector_db", vector_store_id)
return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )