feat(registry): make the Stack query providers for model listing (#2862)

This flips #2823 and #2805 by making the Stack periodically query the
providers for models rather than the providers going behind the back and
calling "register" on to the registry themselves. This also adds support
for model listing for all other providers via `ModelRegistryHelper`.
Once this is done, we do not need to manually list or register models
via `run.yaml` and it will remove both noise and annoyance (setting
`INFERENCE_MODEL` environment variables, for example) from the new user
experience.

In addition, it adds a configuration variable `allowed_models` which can
be used to optionally restrict the set of models exposed from a
provider.
This commit is contained in:
Ashwin Bharambe 2025-07-24 10:39:53 -07:00 committed by GitHub
parent 537dc693ee
commit 1463b79218
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 429 additions and 147 deletions

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.log import get_logger
@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
if not (refresh or provider_id in self.listed_providers):
continue
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
if models is None:
continue
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
async def update_registered_llm_models(
async def update_registered_models(
self,
provider_id: str,
models: list[Model],
@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
if model.provider_id != provider_id:
continue
if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
continue
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
# avoid overwriting a non-provider-registered model entry
continue
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)