feat(registry): more flexible model lookup (#2859)

This PR updates model registration and lookup behavior to be slightly
more general / flexible. See
https://github.com/meta-llama/llama-stack/issues/2843 for more details.

Note that this change is backwards compatible given the design of the
`lookup_model()` method.

## Test Plan

Added unit tests
This commit is contained in:
Ashwin Bharambe 2025-07-22 15:22:48 -07:00 committed by GitHub
parent 9736f096f6
commit 3b83032555
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 265 additions and 75 deletions

View file

@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
all_stores = []
for vector_db in vector_dbs:
try:
vector_store = await self.routing_table.get_provider_impl(
vector_db.identifier
).openai_retrieve_vector_store(vector_db.identifier)
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")