mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 15:27:16 +00:00
fix(models)!: always prefix models with provider_id when registering (#3822)
**!!BREAKING CHANGE!!**

The lookup is also straightforward: we always look for this identifier and don't try to find a match for anything without the `provider_id` prefix.

Note that this ideally means we also need to update the `register_model()` API (we should kill "identifier" from there), but I am not doing that as part of this PR.

## Test Plan

Existing unit tests
This commit is contained in:
parent
f205ab6f6c
commit
f70aa99c97
10 changed files with 53 additions and 124 deletions
|
@@ -245,25 +245,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
|
||||
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
||||
# first try to get the model by identifier
|
||||
# this works if model_id is an alias or is of the form provider_id/provider_model_id
|
||||
model = await routing_table.get_object_by_identifier("model", model_id)
|
||||
if model is not None:
|
||||
return model
|
||||
|
||||
logger.warning(
|
||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
||||
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
|
||||
)
|
||||
# if not found, this means model_id is an unscoped provider_model_id, we need
|
||||
# to iterate (given a lack of an efficient index on the KVStore)
|
||||
models = await routing_table.get_all_with_type("model")
|
||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||
if len(matching_models) == 0:
|
||||
if not model:
|
||||
raise ModelNotFoundError(model_id)
|
||||
|
||||
if len(matching_models) > 1:
|
||||
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
||||
|
||||
return matching_models[0]
|
||||
return model
|
||||
|
|
|
@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
@@ -104,15 +104,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
|
||||
# an identifier different than provider_model_id implies it is an alias, so that
|
||||
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
||||
# so as a general rule we must use the provider_id to disambiguate.
|
||||
|
||||
if model_id != provider_model_id:
|
||||
identifier = model_id
|
||||
else:
|
||||
identifier = f"{provider_id}/{provider_model_id}"
|
||||
|
||||
identifier = f"{provider_id}/{provider_model_id}"
|
||||
model = ModelWithOwner(
|
||||
identifier=identifier,
|
||||
provider_resource_id=provider_model_id,
|
||||
|
|
|
@@ -435,7 +435,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
"""
|
||||
# First check if the model is pre-registered in the model store
|
||||
if hasattr(self, "model_store") and self.model_store:
|
||||
if await self.model_store.has_model(model):
|
||||
qualified_model = f"{self.__provider_id__}/{model}" # type: ignore[attr-defined]
|
||||
if await self.model_store.has_model(qualified_model):
|
||||
return True
|
||||
|
||||
# Then check the provider's dynamic model cache
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue