diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index bbe0113e9..15325276f 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -6,6 +6,7 @@

 from typing import Any

+from llama_stack.apis.models import Model
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
@@ -235,3 +236,34 @@ class CommonRoutingTableImpl(RoutingTable):
             ]

         return filtered_objs
+
+
+async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
+    """Resolve ``model_id`` to a registered model.
+
+    The identifier is first tried as-is via ``get_object_by_identifier``; if
+    that returns nothing, every registered model is scanned for one whose
+    ``provider_resource_id`` equals ``model_id``.
+
+    :param routing_table: routing table used to fetch model objects
+    :param model_id: model identifier, alias, or unscoped provider model id
+    :returns: the matching model; never ``None``, lookup failures raise instead
+    :raises ValueError: if no model matches, or if more than one provider matches
+    """
+    # first try to get the model by identifier
+    # this works if model_id is an alias or is of the form provider_id/provider_model_id
+    model = await routing_table.get_object_by_identifier("model", model_id)
+    if model is not None:
+        return model
+
+    # if not found, this means model_id is an unscoped provider_model_id, we need
+    # to iterate (given a lack of an efficient index on the KVStore)
+    models = await routing_table.get_all_with_type("model")
+    matching_models = [m for m in models if m.provider_resource_id == model_id]
+    if not matching_models:
+        raise ValueError(f"Model '{model_id}' not found")
+
+    if len(matching_models) > 1:
+        raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
+
+    return matching_models[0]
diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index d6fa7ab6b..2f3ce8193 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.log import get_logger

-from .common import CommonRoutingTableImpl
+from .common import CommonRoutingTableImpl, lookup_model

 logger = get_logger(name=__name__, category="core")

@@ -36,23 +36,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return OpenAIListModelsResponse(data=openai_models)

     async def get_model(self, model_id: str) -> Model:
-        # first try to get the model by identifier
-        # this works if model_id is an alias or is of the form provider_id/provider_model_id
-        model = await self.get_object_by_identifier("model", model_id)
-        if model is not None:
-            return model
-
-        # if not found, this means model_id is an unscoped provider_model_id, we need
-        # to iterate (given a lack of an efficient index on the KVStore)
-        models = await self.get_all_with_type("model")
-        matching_models = [m for m in models if m.provider_resource_id == model_id]
-        if len(matching_models) == 0:
-            raise ValueError(f"Model '{model_id}' not found")
-
-        if len(matching_models) > 1:
-            raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
-
-        return matching_models[0]
+        return await lookup_model(self, model_id)

     async def register_model(
         self,
diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py
index b4e60c625..de1458f4c 100644
--- a/llama_stack/distribution/routing_tables/vector_dbs.py
+++ b/llama_stack/distribution/routing_tables/vector_dbs.py
@@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.log import get_logger

-from .common import CommonRoutingTableImpl
+from .common import CommonRoutingTableImpl, lookup_model

 logger = get_logger(name=__name__, category="core")

@@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         provider_vector_db_id: str | None = None,
         vector_db_name: str | None = None,
     ) -> VectorDB:
-        if provider_vector_db_id is None:
-            provider_vector_db_id = vector_db_id
+        provider_vector_db_id = provider_vector_db_id or vector_db_id
         if provider_id is None:
             if len(self.impls_by_provider_id) > 0:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -62,7 +61,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 )
             else:
                 raise ValueError("No provider available. Please configure a vector_io provider.")
-        model = await self.get_object_by_identifier("model", embedding_model)
-        if model is None:
-            raise ValueError(f"Model {embedding_model} not found")
+        # lookup_model raises ValueError when nothing matches, so no None check is needed
+        model = await lookup_model(self, embedding_model)
         if model.model_type != ModelType.embedding: