mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
feat(registry): make the Stack query providers for model listing (#2862)
This flips #2823 and #2805 by making the Stack periodically query the providers for models rather than the providers going behind the back and calling "register" on to the registry themselves. This also adds support for model listing for all other providers via `ModelRegistryHelper`. Once this is done, we do not need to manually list or register models via `run.yaml` and it will remove both noise and annoyance (setting `INFERENCE_MODEL` environment variables, for example) from the new user experience. In addition, it adds a configuration variable `allowed_models` which can be used to optionally restrict the set of models exposed from a provider.
This commit is contained in:
parent
537dc693ee
commit
1463b79218
23 changed files with 429 additions and 147 deletions
|
@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
from .benchmarks import BenchmarksRoutingTable
|
||||
from .datasets import DatasetsRoutingTable
|
||||
|
@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
listed_providers: set[str] = set()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
refresh = await provider.should_refresh_models()
|
||||
if not (refresh or provider_id in self.listed_providers):
|
||||
continue
|
||||
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
if models is None:
|
||||
continue
|
||||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
|
||||
|
@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
source=RegistryEntrySource.via_register_api,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
async def update_registered_llm_models(
|
||||
async def update_registered_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
models: list[Model],
|
||||
|
@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
# from run.yaml) that we need to keep track of
|
||||
model_ids = {}
|
||||
for model in existing_models:
|
||||
# we leave embeddings models alone because often we don't get metadata
|
||||
# (embedding dimension, etc.) from the provider
|
||||
if model.provider_id == provider_id and model.model_type == ModelType.llm:
|
||||
if model.provider_id != provider_id:
|
||||
continue
|
||||
if model.source == RegistryEntrySource.via_register_api:
|
||||
model_ids[model.provider_resource_id] = model.identifier
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
continue
|
||||
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
|
||||
for model in models:
|
||||
if model.model_type != ModelType.llm:
|
||||
continue
|
||||
if model.provider_resource_id in model_ids:
|
||||
model.identifier = model_ids[model.provider_resource_id]
|
||||
# avoid overwriting a non-provider-registered model entry
|
||||
continue
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
|
@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=model.metadata,
|
||||
model_type=model.model_type,
|
||||
source=RegistryEntrySource.listed_from_provider,
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue