mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
This flips #2823 and #2805 by making the Stack periodically query the providers for models, rather than having the providers go behind the Stack's back and call "register" on the registry themselves. This also adds support for model listing for all other providers via `ModelRegistryHelper`. Once this is done, we no longer need to manually list or register models via `run.yaml`, which removes both noise and annoyance (setting `INFERENCE_MODEL` environment variables, for example) from the new-user experience. In addition, it adds a configuration variable `allowed_models`, which can be used to optionally restrict the set of models exposed from a provider.
151 lines
5.8 KiB
Python
151 lines
5.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import time
|
|
from typing import Any
|
|
|
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
|
from llama_stack.distribution.datatypes import (
|
|
ModelWithOwner,
|
|
RegistryEntrySource,
|
|
)
|
|
from llama_stack.log import get_logger
|
|
|
|
from .common import CommonRoutingTableImpl, lookup_model
|
|
|
|
# Module-level logger, created with the "core" category; used below to report
# provider refresh failures and model (un)registration events.
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table that maps model identifiers to the providers serving them.

    Entries reach the registry in two ways, distinguished by their ``source``:

    * ``RegistryEntrySource.via_register_api`` -- created by ``register_model``.
    * ``RegistryEntrySource.listed_from_provider`` -- created by
      ``update_registered_models`` from what a provider's ``list_models()``
      returned during ``refresh``.
    """

    # Provider ids whose ``list_models()`` call has succeeded at least once.
    # NOTE(review): this is a class-level mutable set, so it is shared by every
    # instance of this class unless a base class copies it per instance --
    # confirm that sharing is intended.
    listed_providers: set[str] = set()

    async def refresh(self) -> None:
        """Query providers for their models and sync the results into the registry.

        A provider is queried when it asks for a refresh
        (``should_refresh_models()`` is truthy) or when it has never been
        listed successfully. A provider whose ``list_models()`` raises is
        logged and skipped so it cannot block the remaining providers; because
        it is not recorded in ``listed_providers``, it is retried on the next
        call. A provider returning ``None`` is recorded as listed but leaves
        the registry untouched.
        """
        for provider_id, provider in self.impls_by_provider_id.items():
            wants_refresh = await provider.should_refresh_models()
            never_listed = provider_id not in self.listed_providers
            # Every provider must be listed at least once; after that it is
            # only re-listed when it explicitly requests a refresh. (The
            # membership test was previously inverted, which skipped
            # never-listed providers and always re-listed known ones.)
            if not (wants_refresh or never_listed):
                continue

            try:
                models = await provider.list_models()
            except Exception as e:
                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
                continue

            self.listed_providers.add(provider_id)
            if models is None:
                continue

            await self.update_registered_models(provider_id, models)

    async def list_models(self) -> ListModelsResponse:
        """Return all registry objects of type ``"model"``."""
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """Return all registry objects of type ``"model"`` as OpenAI-style entries.

        ``created`` is set to the time of this call and ``owned_by`` is always
        ``"llama_stack"``.
        """
        models = await self.get_all_with_type("model")
        # Computed once so that every entry in one response carries the same value.
        created = int(time.time())
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=created,
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        """Resolve ``model_id`` to a model via ``lookup_model``."""
        return await lookup_model(self, model_id)

    async def get_provider_impl(self, model_id: str) -> Any:
        """Return the provider implementation that serves ``model_id``.

        The model is resolved via ``lookup_model`` and its ``provider_id`` is
        used to index ``impls_by_provider_id``.
        """
        model = await lookup_model(self, model_id)
        return self.impls_by_provider_id[model.provider_id]

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register a model explicitly and return the registered object.

        Args:
            model_id: Identifier requested by the caller.
            provider_model_id: Name of the model at the provider; defaults to
                ``model_id``.
            provider_id: Provider serving the model. May be omitted only when
                exactly one provider is configured.
            metadata: Extra metadata; defaults to an empty dict. Must contain
                ``"embedding_dimension"`` for embedding models.
            model_type: Defaults to ``ModelType.llm``.

        Raises:
            ValueError: If ``provider_id`` is omitted and the number of
                configured providers is not exactly one, or if an embedding
                model lacks ``"embedding_dimension"`` in its metadata.
        """
        if provider_id is None:
            # Without an explicit provider we can only default to the sole
            # configured provider; no check is made that it serves the model.
            if len(self.impls_by_provider_id) == 1:
                provider_id = next(iter(self.impls_by_provider_id))
            else:
                raise ValueError(
                    f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
                    "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
                )

        provider_model_id = provider_model_id or model_id
        metadata = metadata or {}
        model_type = model_type or ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")

        # An identifier different from provider_model_id implies it is an alias,
        # so that alias becomes the globally unique identifier. Otherwise
        # provider_model_ids can conflict across providers, so as a general
        # rule we use the provider_id as a prefix to disambiguate.
        if model_id != provider_model_id:
            identifier = model_id
        else:
            identifier = f"{provider_id}/{provider_model_id}"

        model = ModelWithOwner(
            identifier=identifier,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
            source=RegistryEntrySource.via_register_api,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        """Remove the model identified by ``model_id`` from the registry.

        Raises:
            ValueError: If ``get_model`` returns ``None`` for ``model_id``.
        """
        existing_model = await self.get_model(model_id)
        # NOTE(review): lookup_model may already raise for unknown ids, which
        # would make this guard redundant -- kept as a defensive check.
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)

    async def update_registered_models(
        self,
        provider_id: str,
        models: list[Model],
    ) -> None:
        """Replace the provider-listed models of ``provider_id`` with ``models``.

        Existing entries of this provider that were registered through the
        register API are kept; all of its other entries are unregistered.
        Each model in ``models`` is then registered with source
        ``listed_from_provider``, except those whose ``provider_resource_id``
        is already covered by a kept entry.

        Args:
            provider_id: Provider whose entries are being synchronized.
            models: Models most recently listed by that provider.
        """
        existing_models = await self.get_all_with_type("model")

        # The user (or initialization from run.yaml) may have registered an
        # alias for a model; remember those provider resource ids so the
        # provider listing does not overwrite them.
        user_registered_resource_ids = set()
        for model in existing_models:
            if model.provider_id != provider_id:
                continue
            if model.source == RegistryEntrySource.via_register_api:
                user_registered_resource_ids.add(model.provider_resource_id)
                continue

            logger.debug(f"unregistering model {model.identifier}")
            await self.unregister_object(model)

        for model in models:
            if model.provider_resource_id in user_registered_resource_ids:
                # avoid overwriting a non-provider-registered model entry
                continue

            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
            await self.register_object(
                ModelWithOwner(
                    identifier=model.identifier,
                    provider_resource_id=model.provider_resource_id,
                    provider_id=provider_id,
                    metadata=model.metadata,
                    model_type=model.model_type,
                    source=RegistryEntrySource.listed_from_provider,
                )
            )
|