From 1463b792182a7df57a99f81f8048090d8352d994 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 24 Jul 2025 10:39:53 -0700
Subject: [PATCH] feat(registry): make the Stack query providers for model listing (#2862)

This flips #2823 and #2805 by making the Stack periodically query the
providers for their models rather than having the providers go behind the
Stack's back and call "register" on the registry themselves. It also adds
support for model listing to all other providers via `ModelRegistryHelper`.

Once this is done, we no longer need to manually list or register models via
`run.yaml`, which removes both noise and annoyance (setting `INFERENCE_MODEL`
environment variables, for example) from the new user experience.

In addition, this adds a configuration variable, `allowed_models`, which can
be used to optionally restrict the set of models exposed by a provider.
---
 .../providers/inference/remote_fireworks.md   |   1 +
 .../providers/inference/remote_ollama.md      |   3 +-
 .../providers/inference/remote_together.md    |   1 +
 .../source/providers/inference/remote_vllm.md |   1 -
 llama_stack/apis/inference/inference.py       |   6 -
 llama_stack/distribution/datatypes.py         |   6 +
 llama_stack/distribution/library_client.py    |   8 +-
 .../distribution/routing_tables/common.py     |   4 +-
 .../distribution/routing_tables/models.py     |  42 +++-
 llama_stack/distribution/stack.py             |  29 +++
 llama_stack/providers/datatypes.py            |  11 +
 .../inference/meta_reference/inference.py     |   6 +
 .../sentence_transformers.py                  |  19 ++
 .../remote/inference/fireworks/config.py      |   5 +-
 .../remote/inference/fireworks/fireworks.py   |   2 +-
 .../remote/inference/ollama/config.py         |   6 +-
 .../remote/inference/ollama/ollama.py         | 117 ++++++-----
 .../remote/inference/together/config.py       |   5 +-
 .../remote/inference/together/together.py     |   2 +-
 .../providers/remote/inference/vllm/config.py |   4 -
 .../providers/remote/inference/vllm/vllm.py   |  73 ++-----
 .../utils/inference/model_registry.py         |  33 ++-
 .../routers/test_routing_tables.py            | 192 ++++++++++++++++++
 23 files changed, 429 insertions(+), 147 deletions(-)

diff --git a/docs/source/providers/inference/remote_fireworks.md b/docs/source/providers/inference/remote_fireworks.md
index 351586c34..862860c29 100644
--- a/docs/source/providers/inference/remote_fireworks.md
+++ b/docs/source/providers/inference/remote_fireworks.md
@@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |

diff --git a/docs/source/providers/inference/remote_ollama.md b/docs/source/providers/inference/remote_ollama.md
index 23b8f87a2..f9f0a7622 100644
--- a/docs/source/providers/inference/remote_ollama.md
+++ b/docs/source/providers/inference/remote_ollama.md
@@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `` | No | http://localhost:11434 | |
-| `refresh_models` | `` | No | False | refresh and re-register models periodically |
-| `refresh_models_interval` | `` | No | 300 | interval in seconds to refresh models |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically |

 ## Sample Configuration

diff --git a/docs/source/providers/inference/remote_together.md b/docs/source/providers/inference/remote_together.md
index f33ff42f2..d1fe3e82b 100644
--- a/docs/source/providers/inference/remote_together.md
+++ b/docs/source/providers/inference/remote_together.md
@@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |

diff --git a/docs/source/providers/inference/remote_vllm.md b/docs/source/providers/inference/remote_vllm.md
index 5291199a4..172d35873 100644
--- a/docs/source/providers/inference/remote_vllm.md
+++ b/docs/source/providers/inference/remote_vllm.md
@@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
-| `refresh_models_interval` | `` | No | 300 | Interval in seconds to refresh models |

 ## Sample Configuration

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b2bb8a8e6..222099064 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

-    async def update_registered_llm_models(
-        self,
-        provider_id: str,
-        models: list[Model],
-    ) -> None: ...
-

 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index ead1331f3..90b269452 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2

 RoutingKey = str | list[str]


+class RegistryEntrySource(StrEnum):
+    via_register_api = "via_register_api"
+    listed_from_provider = "listed_from_provider"
+
+
 class User(BaseModel):
     principal: str
     # further attributes that may be used for access control decisions
@@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
     resource. This can be used to constrain access to the resource."""
     owner: User | None = None
+    source: RegistryEntrySource = RegistryEntrySource.via_register_api


 # Use the extended Resource for all routable objects

diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 6503c13b2..5044fd8c8 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         if not self.skip_logger_removal:
             self._remove_root_logger_handlers()

-        return self.loop.run_until_complete(self.async_client.initialize())
+        # use a new event loop to avoid interfering with the main event loop
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self.async_client.initialize())
+        finally:
+            asyncio.set_event_loop(None)

     def _remove_root_logger_handlers(self):
         """

diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index 2f6ac90bb..caf0780fd 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

+    async def refresh(self) -> None:
+        pass
+
     async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
@@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
             return registered_obj
-        else:
-            await self.dist_registry.register(obj)
-            return obj
+        await self.dist_registry.register(obj)
+        return obj

diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index f2787b308..022c3dd40 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -10,6 +10,7 @@ from typing import Any
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
+    RegistryEntrySource,
 )
 from llama_stack.log import get_logger
@@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")


 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+    listed_providers: set[str] = set()
+
+    async def refresh(self) -> None:
+        for provider_id, provider in self.impls_by_provider_id.items():
+            refresh = await provider.should_refresh_models()
+            # skip a provider only if it doesn't want a periodic refresh
+            # and we have already listed it once
+            if not refresh and provider_id in self.listed_providers:
+                continue
+
+            try:
+                models = await provider.list_models()
+            except Exception as e:
+                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                continue
+
+            self.listed_providers.add(provider_id)
+            if models is None:
+                continue
+
+            await self.update_registered_models(provider_id, models)
+
     async def list_models(self) -> ListModelsResponse:
         return ListModelsResponse(data=await self.get_all_with_type("model"))
@@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id=provider_id,
             metadata=metadata,
             model_type=model_type,
+            source=RegistryEntrySource.via_register_api,
         )
         registered_model = await self.register_object(model)
         return registered_model
@@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model) - async def update_registered_llm_models( + async def update_registered_models( self, provider_id: str, models: list[Model], @@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): # from run.yaml) that we need to keep track of model_ids = {} for model in existing_models: - # we leave embeddings models alone because often we don't get metadata - # (embedding dimension, etc.) from the provider - if model.provider_id == provider_id and model.model_type == ModelType.llm: + if model.provider_id != provider_id: + continue + if model.source == RegistryEntrySource.via_register_api: model_ids[model.provider_resource_id] = model.identifier - logger.debug(f"unregistering model {model.identifier}") - await self.unregister_object(model) + continue + + logger.debug(f"unregistering model {model.identifier}") + await self.unregister_object(model) for model in models: - if model.model_type != ModelType.llm: - continue if model.provider_resource_id in model_ids: - model.identifier = model_ids[model.provider_resource_id] + # avoid overwriting a non-provider-registered model entry + continue logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") await self.register_object( @@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=model.metadata, model_type=model.model_type, + source=RegistryEntrySource.listed_from_provider, ) ) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index d7270156a..57bc4cd5f 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import asyncio
 import importlib.resources
 import os
 import re
@@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
+from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
@@ -90,6 +92,9 @@ RESOURCES = [
 ]


+REGISTRY_REFRESH_INTERVAL_SECONDS = 300
+
+
 async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
     for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config, rsrc)
@@ -324,9 +329,33 @@ async def construct_stack(
     add_internal_implementations(impls, run_config)

     await register_resources(run_config, impls)
+
+    task = asyncio.create_task(refresh_registry(impls))
+
+    def cb(task):
+        import traceback
+
+        if task.cancelled():
+            logger.error("Model refresh task cancelled")
+        elif task.exception():
+            logger.error(f"Model refresh task failed: {task.exception()}")
+            traceback.print_exception(task.exception())
+        else:
+            logger.debug("Model refresh task completed")
+
+    task.add_done_callback(cb)
     return impls


+async def refresh_registry(impls: dict[Api, Any]):
+    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
+    while True:
+        for routing_table in routing_tables:
+            await routing_table.refresh()
+
+        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
+
+
 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
     template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 424380324..005bfbab8 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):

     async def unregister_model(self, model_id: str) -> None: ...

+    # the Stack router will query each provider for its list of models.
+    # if the provider opts in via `should_refresh_models()`, this method will
+    # be called periodically to refresh the list of models
+    #
+    # NOTE: each model returned will be registered with the model registry. this means
+    # a callback to the `register_model()` method will be made. this is duplicative and
+    # may be removed in the future.
+    async def list_models(self) -> list[Model] | None: ...
+
+    async def should_refresh_models(self) -> bool: ...
+

 class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
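
To make the new provider contract concrete, here is a minimal provider-side sketch of the two hooks above. The adapter class and model values are hypothetical and for illustration only; `Model`, `ModelType`, and the hook signatures come from this patch.

from llama_stack.apis.models import Model, ModelType


class ExampleInferenceAdapter:
    # hypothetical adapter; in real providers the resolver sets this attribute
    __provider_id__: str

    async def should_refresh_models(self) -> bool:
        # a static catalog never changes; dynamic backends (e.g. ollama, vllm)
        # return a config-driven value here instead
        return False

    async def list_models(self) -> list[Model] | None:
        # returning None opts the provider out of listing entirely
        return [
            Model(
                identifier="example-model",  # hypothetical model id
                provider_resource_id="example-model",
                provider_id=self.__provider_id__,
                metadata={},
                model_type=ModelType.llm,
            ),
        ]
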
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index e238e1b78..88d7a98ec 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl( if self.config.create_distributed_process_group: self.generator.stop() + async def should_refresh_models(self) -> bool: + return False + + async def list_models(self) -> list[Model] | None: + return None + async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 890c526f5..fea8a8189 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.models import ModelType from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl( InferenceProvider, ModelsProtocolPrivate, ): + __provider_id__: str + def __init__(self, config: SentenceTransformersInferenceConfig) -> None: self.config = config @@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl( async def shutdown(self) -> None: pass + async def should_refresh_models(self) -> bool: + return False + + async def list_models(self) -> list[Model] | None: + return [ + Model( + identifier="all-MiniLM-L6-v2", + provider_resource_id="all-MiniLM-L6-v2", + provider_id=self.__provider_id__, + metadata={ + "embedding_dimension": 384, + }, + model_type=ModelType.embedding, + ), + ] + async def register_model(self, model: Model) -> Model: return model diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 072d558f4..b23f2d31b 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -6,13 +6,14 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import Field, SecretStr +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @json_schema_type -class FireworksImplConfig(BaseModel): +class FireworksImplConfig(RemoteInferenceProviderConfig): url: str = Field( default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 1c82ff3a8..c76aa39f3 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: FireworksImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES, 
config.allowed_models) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index ae261f47c..ce13f0d83 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434" class OllamaImplConfig(BaseModel): url: str = DEFAULT_OLLAMA_URL - refresh_models: bool = Field(default=False, description="refresh and re-register models periodically") - refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models") + refresh_models: bool = Field( + default=False, + description="Whether to refresh models periodically", + ) @classmethod def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 76d789d07..ba20185d3 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -98,14 +98,16 @@ class OllamaInferenceAdapter( def __init__(self, config: OllamaImplConfig) -> None: self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) self.config = config - self._client = None + self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {} self._openai_client = None @property def client(self) -> AsyncClient: - if self._client is None: - self._client = AsyncClient(host=self.config.url) - return self._client + # ollama client attaches itself to the current event loop (sadly?) + loop = asyncio.get_running_loop() + if loop not in self._clients: + self._clients[loop] = AsyncClient(host=self.config.url) + return self._clients[loop] @property def openai_client(self) -> AsyncOpenAI: @@ -121,59 +123,61 @@ class OllamaInferenceAdapter( "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" ) - if self.config.refresh_models: - logger.debug("ollama starting background model refresh task") - self._refresh_task = asyncio.create_task(self._refresh_models()) - - def cb(task): - if task.cancelled(): - import traceback - - logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}") - elif task.exception(): - logger.error(f"ollama background refresh task died: {task.exception()}") - else: - logger.error("ollama background refresh task completed unexpectedly") - - self._refresh_task.add_done_callback(cb) - - async def _refresh_models(self) -> None: - # Wait for model store to be available (with timeout) - waited_time = 0 - while not self.model_store and waited_time < 60: - await asyncio.sleep(1) - waited_time += 1 - - if not self.model_store: - raise ValueError("Model store not set after waiting 60 seconds") + async def should_refresh_models(self) -> bool: + return self.config.refresh_models + async def list_models(self) -> list[Model] | None: provider_id = self.__provider_id__ - while True: - try: - response = await self.client.list() - except Exception as e: - logger.warning(f"Failed to list models: {str(e)}") - await asyncio.sleep(self.config.refresh_models_interval) + response = await self.client.list() + + # always add the two embedding models which can be pulled on demand + models = [ + Model( + identifier="all-minilm:l6-v2", + provider_resource_id="all-minilm:l6-v2", + provider_id=provider_id, 
+ metadata={ + "embedding_dimension": 384, + "context_length": 512, + }, + model_type=ModelType.embedding, + ), + # add all-minilm alias + Model( + identifier="all-minilm", + provider_resource_id="all-minilm:l6-v2", + provider_id=provider_id, + metadata={ + "embedding_dimension": 384, + "context_length": 512, + }, + model_type=ModelType.embedding, + ), + Model( + identifier="nomic-embed-text", + provider_resource_id="nomic-embed-text", + provider_id=provider_id, + metadata={ + "embedding_dimension": 768, + "context_length": 8192, + }, + model_type=ModelType.embedding, + ), + ] + for m in response.models: + # kill embedding models since we don't know dimensions for them + if m.details.family in ["bert"]: continue - - models = [] - for m in response.models: - model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm - if model_type == ModelType.embedding: - continue - models.append( - Model( - identifier=m.model, - provider_resource_id=m.model, - provider_id=provider_id, - metadata={}, - model_type=model_type, - ) + models.append( + Model( + identifier=m.model, + provider_resource_id=m.model, + provider_id=provider_id, + metadata={}, + model_type=ModelType.llm, ) - await self.model_store.update_registered_llm_models(provider_id, models) - logger.debug(f"ollama refreshed model list ({len(models)} models)") - - await asyncio.sleep(self.config.refresh_models_interval) + ) + return models async def health(self) -> HealthResponse: """ @@ -190,12 +194,7 @@ class OllamaInferenceAdapter( return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") async def shutdown(self) -> None: - if hasattr(self, "_refresh_task") and not self._refresh_task.done(): - logger.debug("ollama cancelling background refresh task") - self._refresh_task.cancel() - - self._client = None - self._openai_client = None + self._clients.clear() async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index f166e4277..211be7efe 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -6,13 +6,14 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import Field, SecretStr +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @json_schema_type -class TogetherImplConfig(BaseModel): +class TogetherImplConfig(RemoteInferenceProviderConfig): url: str = Field( default="https://api.together.xyz/v1", description="The URL for the Together AI server", diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e1eb934c5..46094c146 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/vllm/config.py 
b/llama_stack/providers/remote/inference/vllm/config.py
index ee72f974a..a5bf0e4bc 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically",
     )
-    refresh_models_interval: int = Field(
-        default=300,
-        description="Interval in seconds to refresh models",
-    )

     @field_validator("tls_verify")
     @classmethod
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8bdba1e88..621658a48 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
-    _refresh_task: asyncio.Task | None = None

     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None

     async def initialize(self) -> None:
-        if not self.config.url:
-            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
-            # or available in distributions like "starter" without causing a ruckus
-            return
+        pass

-        if self.config.refresh_models:
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                import traceback
-
-                if task.cancelled():
-                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    # print the stack trace for the exception
-                    exc = task.exception()
-                    log.error(f"vLLM background refresh task died: {exc}")
-                    traceback.print_exception(exc)
-                else:
-                    log.error("vLLM background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        provider_id = self.__provider_id__
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models

+    async def list_models(self) -> list[Model] | None:
         self._lazy_initialize_client()
         assert self.client is not None  # mypy
-        while True:
-            try:
-                models = []
-                async for m in self.client.models.list():
-                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
-                    models.append(
-                        Model(
-                            identifier=m.id,
-                            provider_resource_id=m.id,
-                            provider_id=provider_id,
-                            metadata={},
-                            model_type=model_type,
-                        )
-                    )
-                await self.model_store.update_registered_llm_models(provider_id, models)
-                log.debug(f"vLLM refreshed model list ({len(models)} models)")
-            except Exception as e:
-                log.error(f"vLLM background refresh task failed: {e}")
-            await asyncio.sleep(self.config.refresh_models_interval)
+        models = []
+        async for m in self.client.models.list():
+            model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+            models.append(
+                Model(
+                    identifier=m.id,
+                    provider_resource_id=m.id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=model_type,
+                )
+            )
+        return models

     async def shutdown(self) -> None:
-        if self._refresh_task:
-            self._refresh_task.cancel()
-            self._refresh_task = None
+        pass

     async def unregister_model(self, model_id: str) -> None:
         pass

diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 651d58e2a..bceeaf198 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
 logger = get_logger(name=__name__, category="core")


+class RemoteInferenceProviderConfig(BaseModel):
+    allowed_models: list[str] | None = Field(
+        default=None,
+        description="List of models that should be registered with the model registry. If None, all models are allowed.",
+    )
+
+
 # TODO: this class is more confusing than useful right now. We need to make it
 # more closer to the Model class.
 class ProviderModelEntry(BaseModel):
@@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider


 class ModelRegistryHelper(ModelsProtocolPrivate):
-    def __init__(self, model_entries: list[ProviderModelEntry]):
+    __provider_id__: str
+
+    def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
+        self.model_entries = model_entries  # retained; list_models() iterates these entries
+        self.allowed_models = allowed_models
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
         for entry in model_entries:
@@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
             self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model

+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for entry in self.model_entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for id in ids:
+                if self.allowed_models and id not in self.allowed_models:
+                    continue
+                models.append(
+                    Model(
+                        model_id=id,
+                        provider_resource_id=entry.provider_model_id,
+                        model_type=ModelType.llm,
+                        metadata=entry.metadata,
+                        provider_id=self.__provider_id__,
+                    )
+                )
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        return False
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)

diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 12b05ebff..c1b57cb4f 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
 from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.distribution.datatypes import RegistryEntrySource
 from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
 from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
 from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@@ -45,6 +46,30 @@ class InferenceImpl(Impl):
     async def unregister_model(self, model_id: str):
         return model_id

+    async def should_refresh_models(self):
+        return False
+
+    async def
list_models(self): + return [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="provider-model-2", + provider_resource_id="provider-model-2", + provider_id="test_provider", + metadata={"embedding_dimension": 512}, + model_type=ModelType.embedding, + ), + ] + + async def shutdown(self): + pass + class SafetyImpl(Impl): def __init__(self): @@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): raise AssertionError("Should have raised ValueError for non-existent model") except ValueError as e: assert "not found" in str(e) + + +async def test_models_source_tracking_default(cached_disk_dist_registry): + """Test that models registered via register_model get default source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model via register_model (should get default source) + await table.register_model(model_id="user-model", provider_id="test_provider") + + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.source == RegistryEntrySource.via_register_api + assert model.identifier == "test_provider/user-model" + + # Cleanup + await table.shutdown() + + +async def test_models_source_tracking_provider(cached_disk_dist_registry): + """Test that models registered via update_registered_models get provider source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Simulate provider refresh by calling update_registered_models + provider_models = [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="provider-model-2", + provider_resource_id="provider-model-2", + provider_id="test_provider", + metadata={"embedding_dimension": 512}, + model_type=ModelType.embedding, + ), + ] + await table.update_registered_models("test_provider", provider_models) + + models = await table.list_models() + assert len(models.data) == 2 + + # All models should have provider source + for model in models.data: + assert model.source == RegistryEntrySource.listed_from_provider + assert model.provider_id == "test_provider" + + # Cleanup + await table.shutdown() + + +async def test_models_source_interaction_preserves_default(cached_disk_dist_registry): + """Test that provider refresh preserves user-registered models with default source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # First register a user model with same provider_resource_id as provider will later provide + await table.register_model( + model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider" + ) + + # Verify user model is registered with default source + models = await table.list_models() + assert len(models.data) == 1 + user_model = models.data[0] + assert user_model.source == RegistryEntrySource.via_register_api + assert user_model.identifier == "my-custom-alias" + assert user_model.provider_resource_id == "provider-model-1" + + # Now simulate provider refresh + provider_models = [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + 
model_type=ModelType.llm, + ), + Model( + identifier="different-model", + provider_resource_id="different-model", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models) + + # Verify user model with alias is preserved, but provider added new model + models = await table.list_models() + assert len(models.data) == 2 + + # Find the user model and provider model + user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) + provider_model = next((m for m in models.data if m.identifier == "different-model"), None) + + assert user_model is not None + assert user_model.source == RegistryEntrySource.via_register_api + assert user_model.provider_resource_id == "provider-model-1" + + assert provider_model is not None + assert provider_model.source == RegistryEntrySource.listed_from_provider + assert provider_model.provider_resource_id == "different-model" + + # Cleanup + await table.shutdown() + + +async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry): + """Test that provider refresh removes old provider models but keeps default ones.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register a user model + await table.register_model(model_id="user-model", provider_id="test_provider") + + # Add some provider models + provider_models_v1 = [ + Model( + identifier="provider-model-old", + provider_resource_id="provider-model-old", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models_v1) + + # Verify we have both user and provider models + models = await table.list_models() + assert len(models.data) == 2 + + # Now update with new provider models (should remove old provider models) + provider_models_v2 = [ + Model( + identifier="provider-model-new", + provider_resource_id="provider-model-new", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models_v2) + + # Should have user model + new provider model, old provider model gone + models = await table.list_models() + assert len(models.data) == 2 + + identifiers = {m.identifier for m in models.data} + assert "test_provider/user-model" in identifiers # User model preserved + assert "provider-model-new" in identifiers # New provider model (uses provider's identifier) + assert "provider-model-old" not in identifiers # Old provider model removed + + # Verify sources are correct + user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) + provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None) + + assert user_model.source == RegistryEntrySource.via_register_api + assert provider_model.source == RegistryEntrySource.listed_from_provider + + # Cleanup + await table.shutdown()
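
For completeness, a sketch of how the new `refresh()` path could be exercised end to end. `RefreshingInferenceImpl` and the test name are hypothetical; `InferenceImpl`, `ModelsRoutingTable`, and `RegistryEntrySource` are the ones defined above, and the assertions rely on `InferenceImpl.list_models()` returning the two provider models.

class RefreshingInferenceImpl(InferenceImpl):
    # hypothetical variant of the fixture above that opts in to periodic refreshes
    async def should_refresh_models(self):
        return True


async def test_models_refresh_registers_provider_models(cached_disk_dist_registry):
    """Sketch: refresh() should pull list_models() results into the registry."""
    table = ModelsRoutingTable({"test_provider": RefreshingInferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # one refresh pass queries the provider and registers what it lists
    await table.refresh()

    models = await table.list_models()
    assert {m.identifier for m in models.data} == {"provider-model-1", "provider-model-2"}
    assert all(m.source == RegistryEntrySource.listed_from_provider for m in models.data)

    # Cleanup
    await table.shutdown()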