From 352bf3ec56e8da91076a242a7c75dee7fed19d4d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 22 Jul 2025 09:08:51 -0700 Subject: [PATCH] feat(registry): more flexible model lookup --- .../distribution/routing_tables/models.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 9a9db7257..dc3dcf5b2 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -36,10 +36,23 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return OpenAIListModelsResponse(data=openai_models) async def get_model(self, model_id: str) -> Model: + # first try to get the model by identifier + # this works if model_id is an alias or is of the form provider_id/provider_model_id model = await self.get_object_by_identifier("model", model_id) - if model is None: + if model is not None: + return model + + # if not found, this means model_id is an unscoped provider_model_id, we need + # to iterate (given a lack of an efficient index on the KVStore) + models = await self.get_all_with_type("model") + matching_models = [m for m in models if m.provider_resource_id == model_id] + if len(matching_models) == 0: raise ValueError(f"Model '{model_id}' not found") - return model + + if len(matching_models) > 1: + raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}") + + return matching_models[0] async def register_model( self, @@ -49,24 +62,32 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): metadata: dict[str, Any] | None = None, model_type: ModelType | None = None, ) -> Model: - if provider_model_id is None: - provider_model_id = model_id if provider_id is None: # If provider_id not specified, use the only provider if it supports this model if len(self.impls_by_provider_id) == 1: provider_id = list(self.impls_by_provider_id.keys())[0] else: raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}" ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm + + provider_model_id = provider_model_id or model_id + metadata = metadata or {} + model_type = model_type or ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") + + # an identifier different than provider_model_id implies it is an alias, so that + # becomes the globally unique identifier. otherwise provider_model_ids can conflict, + # so as a general rule we must use the provider_id to disambiguate. + + if model_id != provider_model_id: + identifier = model_id + else: + identifier = f"{provider_id}/{provider_model_id}" + model = ModelWithOwner( - identifier=model_id, + identifier=identifier, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata,