mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 22:19:49 +00:00
feat(registry): more flexible model lookup (#2859)
This PR updates model registration and lookup behavior to be slightly more general / flexible. See https://github.com/meta-llama/llama-stack/issues/2843 for more details. Note that this change is backwards compatible given the design of the `lookup_model()` method. ## Test Plan Added unit tests
This commit is contained in:
parent
9736f096f6
commit
3b83032555
15 changed files with 265 additions and 75 deletions
|
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -36,10 +36,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
return model
|
||||
return await lookup_model(self, model_id)
|
||||
|
||||
async def get_provider_impl(self, model_id: str) -> Any:
|
||||
model = await lookup_model(self, model_id)
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
@ -49,24 +50,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this model
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
|
||||
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
|
||||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
|
||||
provider_model_id = provider_model_id or model_id
|
||||
metadata = metadata or {}
|
||||
model_type = model_type or ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
|
||||
# an identifier different than provider_model_id implies it is an alias, so that
|
||||
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
||||
# so as a general rule we must use the provider_id to disambiguate.
|
||||
|
||||
if model_id != provider_model_id:
|
||||
identifier = model_id
|
||||
else:
|
||||
identifier = f"{provider_id}/{provider_model_id}"
|
||||
|
||||
model = ModelWithOwner(
|
||||
identifier=model_id,
|
||||
identifier=identifier,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue