mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:39:41 +00:00
feat(registry): more flexible model lookup
This commit is contained in:
parent
b5a6ecc331
commit
352bf3ec56
1 changed file with 31 additions and 10 deletions
|
|
@ -36,11 +36,24 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
return OpenAIListModelsResponse(data=openai_models)
|
return OpenAIListModelsResponse(data=openai_models)
|
||||||
|
|
||||||
async def get_model(self, model_id: str) -> Model:
|
async def get_model(self, model_id: str) -> Model:
|
||||||
|
# first try to get the model by identifier
|
||||||
|
# this works if model_id is an alias or is of the form provider_id/provider_model_id
|
||||||
model = await self.get_object_by_identifier("model", model_id)
|
model = await self.get_object_by_identifier("model", model_id)
|
||||||
if model is None:
|
if model is not None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
# if not found, this means model_id is an unscoped provider_model_id, we need
|
||||||
|
# to iterate (given a lack of an efficient index on the KVStore)
|
||||||
|
models = await self.get_all_with_type("model")
|
||||||
|
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||||
|
if len(matching_models) == 0:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
|
||||||
|
if len(matching_models) > 1:
|
||||||
|
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
||||||
|
|
||||||
|
return matching_models[0]
|
||||||
|
|
||||||
async def register_model(
|
async def register_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
@ -49,24 +62,32 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_type: ModelType | None = None,
|
model_type: ModelType | None = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_model_id is None:
|
|
||||||
provider_model_id = model_id
|
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
# If provider_id not specified, use the only provider if it supports this model
|
# If provider_id not specified, use the only provider if it supports this model
|
||||||
if len(self.impls_by_provider_id) == 1:
|
if len(self.impls_by_provider_id) == 1:
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}"
|
||||||
)
|
)
|
||||||
if metadata is None:
|
|
||||||
metadata = {}
|
provider_model_id = provider_model_id or model_id
|
||||||
if model_type is None:
|
metadata = metadata or {}
|
||||||
model_type = ModelType.llm
|
model_type = model_type or ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
|
|
||||||
|
# an identifier different than provider_model_id implies it is an alias, so that
|
||||||
|
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
||||||
|
# so as a general rule we must use the provider_id to disambiguate.
|
||||||
|
|
||||||
|
if model_id != provider_model_id:
|
||||||
|
identifier = model_id
|
||||||
|
else:
|
||||||
|
identifier = f"{provider_id}/{provider_model_id}"
|
||||||
|
|
||||||
model = ModelWithOwner(
|
model = ModelWithOwner(
|
||||||
identifier=model_id,
|
identifier=identifier,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue