mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 03:22:26 +00:00
more fixes
This commit is contained in:
parent
a66074a10e
commit
d76b3aa4d2
5 changed files with 37 additions and 25 deletions
|
|
@ -50,7 +50,8 @@ class SafetyRouter(Safety):
|
|||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
|
|
|
|||
|
|
@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
|
|
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
|
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
|
|||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.query_chunks(vector_db_id, query, params)
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
async def openai_create_vector_store(
|
||||
|
|
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
|
|||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=name,
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
return await provider.openai_create_vector_store(
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
|
|
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
|
|||
all_stores = []
|
||||
for vector_db in vector_dbs:
|
||||
try:
|
||||
vector_store = await self.routing_table.get_provider_impl(
|
||||
vector_db.identifier
|
||||
).openai_retrieve_vector_store(vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue