feat(registry): make the Stack query providers for model listing (#2862)

This flips #2823 and #2805 by making the Stack periodically query the
providers for models rather than having the providers go behind the
Stack's back and call "register" on the registry themselves. It also adds
model-listing support for all other providers via `ModelRegistryHelper`.
Once this is done, we no longer need to manually list or register models
via `run.yaml`, which removes both noise and annoyance (setting
`INFERENCE_MODEL` environment variables, for example) from the new-user
experience.

In addition, it adds a configuration field `allowed_models` which can be
used to optionally restrict the set of models exposed by a provider.
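
For illustration, here is roughly how the new field could be used when constructing a provider config in code (a sketch only; the import path and model ID are assumptions, not taken from this diff):

```python
# Sketch: restricting which models a provider exposes via allowed_models.
# The module path and the model ID below are illustrative assumptions.
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig

config = FireworksImplConfig(
    # only this model (out of the provider's full catalog) will be registered
    allowed_models=["accounts/fireworks/models/llama-v3p1-8b-instruct"],
)
```

Models listed by the provider that are not in `allowed_models` are simply skipped during registration.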
Ashwin Bharambe 2025-07-24 10:39:53 -07:00 committed by GitHub
parent 537dc693ee
commit 1463b79218
23 changed files with 429 additions and 147 deletions


@@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |


@@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `<class 'str'>` | No | http://localhost:11434 | |
-| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
-| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
+| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
 
 ## Sample Configuration


@@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |


@@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
-| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
 
 ## Sample Configuration


@@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
-    async def update_registered_llm_models(
-        self,
-        provider_id: str,
-        models: list[Model],
-    ) -> None: ...
-
 
 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.


@@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
 RoutingKey = str | list[str]
 
 
+class RegistryEntrySource(StrEnum):
+    via_register_api = "via_register_api"
+    listed_from_provider = "listed_from_provider"
+
+
 class User(BaseModel):
     principal: str
     # further attributes that may be used for access control decisions
@@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
     resource. This can be used to constrain access to the resource."""
 
     owner: User | None = None
+    source: RegistryEntrySource = RegistryEntrySource.via_register_api
 
 
 # Use the extended Resource for all routable objects


@@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         if not self.skip_logger_removal:
             self._remove_root_logger_handlers()
 
-        return self.loop.run_until_complete(self.async_client.initialize())
+        # use a new event loop to avoid interfering with the main event loop
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self.async_client.initialize())
+        finally:
+            asyncio.set_event_loop(None)
 
     def _remove_root_logger_handlers(self):
         """


@@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
+    async def refresh(self) -> None:
+        pass
+
     async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
@@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
             return registered_obj
         else:
             await self.dist_registry.register(obj)
             return obj


@@ -10,6 +10,7 @@
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
+    RegistryEntrySource,
 )
 from llama_stack.log import get_logger
@@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+    listed_providers: set[str] = set()
+
+    async def refresh(self) -> None:
+        for provider_id, provider in self.impls_by_provider_id.items():
+            refresh = await provider.should_refresh_models()
+            if not (refresh or provider_id not in self.listed_providers):
+                continue
+
+            try:
+                models = await provider.list_models()
+            except Exception as e:
+                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                continue
+
+            self.listed_providers.add(provider_id)
+            if models is None:
+                continue
+
+            await self.update_registered_models(provider_id, models)
+
     async def list_models(self) -> ListModelsResponse:
         return ListModelsResponse(data=await self.get_all_with_type("model"))
@@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id=provider_id,
             metadata=metadata,
             model_type=model_type,
+            source=RegistryEntrySource.via_register_api,
         )
         registered_model = await self.register_object(model)
         return registered_model
@@ -91,7 +113,7 @@
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
 
-    async def update_registered_llm_models(
+    async def update_registered_models(
         self,
         provider_id: str,
         models: list[Model],
@@ -102,18 +124,19 @@
         # from run.yaml) that we need to keep track of
         model_ids = {}
         for model in existing_models:
-            # we leave embeddings models alone because often we don't get metadata
-            # (embedding dimension, etc.) from the provider
-            if model.provider_id == provider_id and model.model_type == ModelType.llm:
-                model_ids[model.provider_resource_id] = model.identifier
-                logger.debug(f"unregistering model {model.identifier}")
-                await self.unregister_object(model)
+            if model.provider_id != provider_id:
+                continue
+            if model.source == RegistryEntrySource.via_register_api:
+                model_ids[model.provider_resource_id] = model.identifier
+                continue
+
+            logger.debug(f"unregistering model {model.identifier}")
+            await self.unregister_object(model)
 
         for model in models:
-            if model.model_type != ModelType.llm:
-                continue
             if model.provider_resource_id in model_ids:
-                model.identifier = model_ids[model.provider_resource_id]
+                # avoid overwriting a non-provider-registered model entry
+                continue
 
             logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
             await self.register_object(
@@ -123,5 +146,6 @@
                     provider_id=provider_id,
                     metadata=model.metadata,
                     model_type=model.model_type,
+                    source=RegistryEntrySource.listed_from_provider,
                 )
             )


@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import importlib.resources
 import os
 import re
@@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
+from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
@@ -90,6 +92,9 @@ RESOURCES = [
 ]
 
+REGISTRY_REFRESH_INTERVAL_SECONDS = 300
+
 
 async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
     for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config, rsrc)
@@ -324,9 +329,33 @@
     add_internal_implementations(impls, run_config)
 
     await register_resources(run_config, impls)
+
+    task = asyncio.create_task(refresh_registry(impls))
+
+    def cb(task):
+        import traceback
+
+        if task.cancelled():
+            logger.error("Model refresh task cancelled")
+        elif task.exception():
+            logger.error(f"Model refresh task failed: {task.exception()}")
+            traceback.print_exception(task.exception())
+        else:
+            logger.debug("Model refresh task completed")
+
+    task.add_done_callback(cb)
+
     return impls
 
 
+async def refresh_registry(impls: dict[Api, Any]):
+    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
+    while True:
+        for routing_table in routing_tables:
+            await routing_table.refresh()
+        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
+
+
 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
     template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"


@@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
     async def unregister_model(self, model_id: str) -> None: ...
 
+    # the Stack router will query each provider for their list of models
+    # if a `refresh_interval_seconds` is provided, this method will be called
+    # periodically to refresh the list of models
+    #
+    # NOTE: each model returned will be registered with the model registry. this means
+    # a callback to the `register_model()` method will be made. this is duplicative and
+    # may be removed in the future.
+    async def list_models(self) -> list[Model] | None: ...
+
+    async def should_refresh_models(self) -> bool: ...
+
 
 class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
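
To make the contract concrete, a provider with a fixed catalog could satisfy these two hooks roughly as follows (a sketch; the class name and model identifier are illustrative assumptions, not code from this change):

```python
# Sketch of the two new ModelsProtocolPrivate hooks for a provider whose
# catalog never changes; the class name and model identifier are made up.
from llama_stack.apis.models import Model, ModelType


class StaticCatalogInferenceProvider:
    __provider_id__: str  # set by the resolver when the provider is instantiated

    async def should_refresh_models(self) -> bool:
        # static catalog: listing once at startup is enough
        return False

    async def list_models(self) -> list[Model] | None:
        # every Model returned here ends up registered in the Stack's registry
        return [
            Model(
                identifier="my-llm",
                provider_resource_id="my-llm",
                provider_id=self.__provider_id__,
                metadata={},
                model_type=ModelType.llm,
            )
        ]
```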


@@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        return None
+
     async def unregister_model(self, model_id: str) -> None:
         pass


@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
+    __provider_id__: str
+
     def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
         self.config = config
@@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass
 
+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        return [
+            Model(
+                identifier="all-MiniLM-L6-v2",
+                provider_resource_id="all-MiniLM-L6-v2",
+                provider_id=self.__provider_id__,
+                metadata={
+                    "embedding_dimension": 384,
+                },
+                model_type=ModelType.embedding,
+            ),
+        ]
+
     async def register_model(self, model: Model) -> Model:
         return model


@@ -6,13 +6,14 @@
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class FireworksImplConfig(BaseModel):
+class FireworksImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",


@@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
 
     async def initialize(self) -> None:


@@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
-    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:


@@ -98,14 +98,16 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.config = config
-        self._client = None
+        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
         self._openai_client = None
 
     @property
     def client(self) -> AsyncClient:
-        if self._client is None:
-            self._client = AsyncClient(host=self.config.url)
-        return self._client
+        # ollama client attaches itself to the current event loop (sadly?)
+        loop = asyncio.get_running_loop()
+        if loop not in self._clients:
+            self._clients[loop] = AsyncClient(host=self.config.url)
+        return self._clients[loop]
 
     @property
     def openai_client(self) -> AsyncOpenAI:
@@ -121,45 +123,50 @@
                 "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
             )
 
-        if self.config.refresh_models:
-            logger.debug("ollama starting background model refresh task")
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                if task.cancelled():
-                    import traceback
-
-                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    logger.error(f"ollama background refresh task died: {task.exception()}")
-                else:
-                    logger.error("ollama background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        # Wait for model store to be available (with timeout)
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
-
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models
+
+    async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        while True:
-            try:
         response = await self.client.list()
-            except Exception as e:
-                logger.warning(f"Failed to list models: {str(e)}")
-                await asyncio.sleep(self.config.refresh_models_interval)
-                continue
 
-            models = []
+        # always add the two embedding models which can be pulled on demand
+        models = [
+            Model(
+                identifier="all-minilm:l6-v2",
+                provider_resource_id="all-minilm:l6-v2",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 384,
+                    "context_length": 512,
+                },
+                model_type=ModelType.embedding,
+            ),
+            # add all-minilm alias
+            Model(
+                identifier="all-minilm",
+                provider_resource_id="all-minilm:l6-v2",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 384,
+                    "context_length": 512,
+                },
+                model_type=ModelType.embedding,
+            ),
+            Model(
+                identifier="nomic-embed-text",
+                provider_resource_id="nomic-embed-text",
+                provider_id=provider_id,
+                metadata={
+                    "embedding_dimension": 768,
+                    "context_length": 8192,
+                },
+                model_type=ModelType.embedding,
+            ),
+        ]
         for m in response.models:
-            model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-            if model_type == ModelType.embedding:
+            # kill embedding models since we don't know dimensions for them
+            if m.details.family in ["bert"]:
                 continue
             models.append(
                 Model(
@@ -167,13 +174,10 @@ class OllamaInferenceAdapter(
                     provider_resource_id=m.model,
                     provider_id=provider_id,
                     metadata={},
-                    model_type=model_type,
+                    model_type=ModelType.llm,
                 )
             )
-            await self.model_store.update_registered_llm_models(provider_id, models)
-            logger.debug(f"ollama refreshed model list ({len(models)} models)")
-
-            await asyncio.sleep(self.config.refresh_models_interval)
+        return models
 
     async def health(self) -> HealthResponse:
         """
@@ -190,12 +194,7 @@
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
     async def shutdown(self) -> None:
-        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
-            logger.debug("ollama cancelling background refresh task")
-            self._refresh_task.cancel()
-
-        self._client = None
-        self._openai_client = None
+        self._clients.clear()
 
     async def unregister_model(self, model_id: str) -> None:
         pass


@@ -6,13 +6,14 @@
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class TogetherImplConfig(BaseModel):
+class TogetherImplConfig(RemoteInferenceProviderConfig):
     url: str = Field(
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",


@@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
 class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
 
     async def initialize(self) -> None:


@@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically",
     )
-    refresh_models_interval: int = Field(
-        default=300,
-        description="Interval in seconds to refresh models",
-    )
 
     @field_validator("tls_verify")
     @classmethod


@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
-    _refresh_task: asyncio.Task | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@@ -301,43 +299,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None
 
     async def initialize(self) -> None:
-        if not self.config.url:
-            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
-            # or available in distributions like "starter" without causing a ruckus
-            return
-
-        if self.config.refresh_models:
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                import traceback
-
-                if task.cancelled():
-                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    # print the stack trace for the exception
-                    exc = task.exception()
-                    log.error(f"vLLM background refresh task died: {exc}")
-                    traceback.print_exception(exc)
-                else:
-                    log.error("vLLM background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        provider_id = self.__provider_id__
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
-
+        pass
+
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models
+
+    async def list_models(self) -> list[Model] | None:
         self._lazy_initialize_client()
         assert self.client is not None  # mypy
-        while True:
-            try:
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -345,21 +314,15 @@
                 Model(
                     identifier=m.id,
                     provider_resource_id=m.id,
-                    provider_id=provider_id,
+                    provider_id=self.__provider_id__,
                     metadata={},
                     model_type=model_type,
                 )
             )
-                await self.model_store.update_registered_llm_models(provider_id, models)
-                log.debug(f"vLLM refreshed model list ({len(models)} models)")
-            except Exception as e:
-                log.error(f"vLLM background refresh task failed: {e}")
-            await asyncio.sleep(self.config.refresh_models_interval)
+        return models
 
     async def shutdown(self) -> None:
-        if self._refresh_task:
-            self._refresh_task.cancel()
-            self._refresh_task = None
+        pass
 
     async def unregister_model(self, model_id: str) -> None:
         pass


@@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
 logger = get_logger(name=__name__, category="core")
 
+
+class RemoteInferenceProviderConfig(BaseModel):
+    allowed_models: list[str] | None = Field(
+        default=None,
+        description="List of models that should be registered with the model registry. If None, all models are allowed.",
+    )
+
+
 # TODO: this class is more confusing than useful right now. We need to make it
 # more closer to the Model class.
 class ProviderModelEntry(BaseModel):
@@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
 class ModelRegistryHelper(ModelsProtocolPrivate):
-    def __init__(self, model_entries: list[ProviderModelEntry]):
+    __provider_id__: str
+
+    def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
+        self.allowed_models = allowed_models
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
         for entry in model_entries:
@@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
             self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for entry in self.model_entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for id in ids:
+                if self.allowed_models and id not in self.allowed_models:
+                    continue
+                models.append(
+                    Model(
+                        model_id=id,
+                        provider_resource_id=entry.provider_model_id,
+                        model_type=ModelType.llm,
+                        metadata=entry.metadata,
+                        provider_id=self.__provider_id__,
+                    )
+                )
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        return False
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)


@@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
 from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.distribution.datatypes import RegistryEntrySource
 from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
 from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
 from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@@ -45,6 +46,30 @@ class InferenceImpl(Impl):
     async def unregister_model(self, model_id: str):
         return model_id
 
+    async def should_refresh_models(self):
+        return False
+
+    async def list_models(self):
+        return [
+            Model(
+                identifier="provider-model-1",
+                provider_resource_id="provider-model-1",
+                provider_id="test_provider",
+                metadata={},
+                model_type=ModelType.llm,
+            ),
+            Model(
+                identifier="provider-model-2",
+                provider_resource_id="provider-model-2",
+                provider_id="test_provider",
+                metadata={"embedding_dimension": 512},
+                model_type=ModelType.embedding,
+            ),
+        ]
+
+    async def shutdown(self):
+        pass
+
 
 class SafetyImpl(Impl):
     def __init__(self):
@@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
         raise AssertionError("Should have raised ValueError for non-existent model")
     except ValueError as e:
         assert "not found" in str(e)
+
+
+async def test_models_source_tracking_default(cached_disk_dist_registry):
+    """Test that models registered via register_model get default source."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register model via register_model (should get default source)
+    await table.register_model(model_id="user-model", provider_id="test_provider")
+
+    models = await table.list_models()
+    assert len(models.data) == 1
+    model = models.data[0]
+    assert model.source == RegistryEntrySource.via_register_api
+    assert model.identifier == "test_provider/user-model"
+
+    # Cleanup
+    await table.shutdown()
+
+
+async def test_models_source_tracking_provider(cached_disk_dist_registry):
+    """Test that models registered via update_registered_models get provider source."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Simulate provider refresh by calling update_registered_models
+    provider_models = [
+        Model(
+            identifier="provider-model-1",
+            provider_resource_id="provider-model-1",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+        Model(
+            identifier="provider-model-2",
+            provider_resource_id="provider-model-2",
+            provider_id="test_provider",
+            metadata={"embedding_dimension": 512},
+            model_type=ModelType.embedding,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models)
+
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    # All models should have provider source
+    for model in models.data:
+        assert model.source == RegistryEntrySource.listed_from_provider
+        assert model.provider_id == "test_provider"
+
+    # Cleanup
+    await table.shutdown()
+
+
+async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
+    """Test that provider refresh preserves user-registered models with default source."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # First register a user model with same provider_resource_id as provider will later provide
+    await table.register_model(
+        model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
+    )
+
+    # Verify user model is registered with default source
+    models = await table.list_models()
+    assert len(models.data) == 1
+    user_model = models.data[0]
+    assert user_model.source == RegistryEntrySource.via_register_api
+    assert user_model.identifier == "my-custom-alias"
+    assert user_model.provider_resource_id == "provider-model-1"
+
+    # Now simulate provider refresh
+    provider_models = [
+        Model(
+            identifier="provider-model-1",
+            provider_resource_id="provider-model-1",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+        Model(
+            identifier="different-model",
+            provider_resource_id="different-model",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models)
+
+    # Verify user model with alias is preserved, but provider added new model
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    # Find the user model and provider model
+    user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
+    provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
+
+    assert user_model is not None
+    assert user_model.source == RegistryEntrySource.via_register_api
+    assert user_model.provider_resource_id == "provider-model-1"
+
+    assert provider_model is not None
+    assert provider_model.source == RegistryEntrySource.listed_from_provider
+    assert provider_model.provider_resource_id == "different-model"
+
+    # Cleanup
+    await table.shutdown()
+
+
+async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
+    """Test that provider refresh removes old provider models but keeps default ones."""
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
+    await table.initialize()
+
+    # Register a user model
+    await table.register_model(model_id="user-model", provider_id="test_provider")
+
+    # Add some provider models
+    provider_models_v1 = [
+        Model(
+            identifier="provider-model-old",
+            provider_resource_id="provider-model-old",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models_v1)
+
+    # Verify we have both user and provider models
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    # Now update with new provider models (should remove old provider models)
+    provider_models_v2 = [
+        Model(
+            identifier="provider-model-new",
+            provider_resource_id="provider-model-new",
+            provider_id="test_provider",
+            metadata={},
+            model_type=ModelType.llm,
+        ),
+    ]
+    await table.update_registered_models("test_provider", provider_models_v2)
+
+    # Should have user model + new provider model, old provider model gone
+    models = await table.list_models()
+    assert len(models.data) == 2
+
+    identifiers = {m.identifier for m in models.data}
+    assert "test_provider/user-model" in identifiers  # User model preserved
+    assert "provider-model-new" in identifiers  # New provider model (uses provider's identifier)
+    assert "provider-model-old" not in identifiers  # Old provider model removed
+
+    # Verify sources are correct
+    user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
+    provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
+
+    assert user_model.source == RegistryEntrySource.via_register_api
+    assert provider_model.source == RegistryEntrySource.listed_from_provider
+
+    # Cleanup
+    await table.shutdown()