llama-stack-mirror/llama_stack/distribution/routing_tables/models.py
Ashwin Bharambe 3b83032555
feat(registry): more flexible model lookup (#2859)
This PR updates model registration and lookup behavior to be slightly
more general / flexible. See
https://github.com/meta-llama/llama-stack/issues/2843 for more details.

Note that this change is backwards compatible given the design of the
`lookup_model()` method.

## Test Plan

Added unit tests
2025-07-22 15:22:48 -07:00

127 lines
5.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
return await lookup_model(self, model_id)
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
return self.impls_by_provider_id[model.provider_id]
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
)
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
# an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=identifier,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,
provider_resource_id=model.provider_resource_id,
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
)
)