refactor lookup_model out so it can be used by vector dbs routing table

This commit is contained in:
Ashwin Bharambe 2025-07-22 10:08:29 -07:00
parent d3dee496ec
commit 50d16dc707
3 changed files with 26 additions and 22 deletions

View file

@ -6,6 +6,7 @@
from typing import Any
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
@ -235,3 +236,23 @@ class CommonRoutingTableImpl(RoutingTable):
]
return filtered_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await routing_table.get_object_by_identifier("model", model_id)
if model is not None:
return model
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ValueError(f"Model '{model_id}' not found")
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]

View file

@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
@ -36,23 +36,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await self.get_object_by_identifier("model", model_id)
if model is not None:
return model
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await self.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ValueError(f"Model '{model_id}' not found")
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]
return await lookup_model(self, model_id)
async def register_model(
self,

View file

@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
model = await lookup_model(self, embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding: