diff --git a/docs/source/providers/inference/remote_ollama.md b/docs/source/providers/inference/remote_ollama.md
index 23b8f87a2..f9f0a7622 100644
--- a/docs/source/providers/inference/remote_ollama.md
+++ b/docs/source/providers/inference/remote_ollama.md
@@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `` | No | http://localhost:11434 |  |
-| `refresh_models` | `` | No | False | refresh and re-register models periodically |
-| `refresh_models_interval` | `` | No | 300 | interval in seconds to refresh models |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically |
 
 ## Sample Configuration
 
diff --git a/docs/source/providers/inference/remote_vllm.md b/docs/source/providers/inference/remote_vllm.md
index 5291199a4..172d35873 100644
--- a/docs/source/providers/inference/remote_vllm.md
+++ b/docs/source/providers/inference/remote_vllm.md
@@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
-| `refresh_models_interval` | `` | No | 300 | Interval in seconds to refresh models |
 
 ## Sample Configuration
 
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b2bb8a8e6..222099064 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
-    async def update_registered_llm_models(
-        self,
-        provider_id: str,
-        models: list[Model],
-    ) -> None: ...
-
 
 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index f2787b308..437db0176 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import time
 from typing import Any
 
@@ -19,6 +20,47 @@ logger = get_logger(name=__name__, category="core")
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+    listed_providers: set[str] = set()
+    model_refresh_interval_seconds: int = 300
+
+    async def initialize(self) -> None:
+        await super().initialize()
+        task = asyncio.create_task(self._refresh_models())
+
+        def cb(task):
+            import traceback
+
+            if task.cancelled():
+                logger.error("Model refresh task cancelled")
+            elif task.exception():
+                logger.error(f"Model refresh task failed: {task.exception()}")
+                traceback.print_exception(task.exception())
+            else:
+                logger.debug("Model refresh task completed")
+
+        task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        while True:
+            for provider_id, provider in self.impls_by_provider_id.items():
+                refresh = await provider.should_refresh_models()
+                if not (refresh or provider_id not in self.listed_providers):
+                    continue
+
+                try:
+                    models = await provider.list_models()
+                except Exception as e:
+                    logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                    continue
+
+                self.listed_providers.add(provider_id)
+                if models is None:
+                    continue
+
+                await self.update_registered_llm_models(provider_id, models)
+
+            await asyncio.sleep(self.model_refresh_interval_seconds)
+
     async def list_models(self) -> ListModelsResponse:
         return ListModelsResponse(data=await self.get_all_with_type("model"))
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 424380324..005bfbab8 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
 
     async def unregister_model(self, model_id: str) -> None: ...
 
+    # the Stack router will query each provider for their list of models.
+    # if `should_refresh_models()` returns True, this method will be called
+    # periodically to refresh the list of models
+    #
+    # NOTE: each model returned will be registered with the model registry. this means
+    # a callback to the `register_model()` method will be made. this is duplicative and
+    # may be removed in the future.
+    async def list_models(self) -> list[Model] | None: ...
+
+    async def should_refresh_models(self) -> bool: ...
+
 
 class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
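
The two methods added to `ModelsProtocolPrivate` above are the entire provider-side contract for dynamic model discovery: the routing table calls `should_refresh_models()` to decide whether to poll a provider again, and `list_models()` to fetch its catalog. The sketch below shows what a minimal implementation might look like for a hypothetical dynamic backend; `FakeDynamicAdapter`, `_fetch_model_names`, and the `fake-backend` provider id are invented for illustration and are not part of llama_stack.

```python
# Hypothetical sketch only: a provider satisfying the new ModelsProtocolPrivate
# methods. `_fetch_model_names` and "fake-backend" are invented stand-ins.
from llama_stack.apis.models import Model, ModelType


class FakeDynamicAdapter:
    __provider_id__: str = "fake-backend"  # normally set by the resolver

    def __init__(self, refresh_models: bool = False) -> None:
        self.refresh_models = refresh_models

    async def should_refresh_models(self) -> bool:
        # Opt in to periodic polling only when the backend's catalog can change.
        return self.refresh_models

    async def list_models(self) -> list[Model] | None:
        # Returning None tells the routing table there is nothing to register.
        names = await self._fetch_model_names()
        return [
            Model(
                identifier=name,
                provider_resource_id=name,
                provider_id=self.__provider_id__,
                metadata={},
                model_type=ModelType.llm,
            )
            for name in names
        ]

    async def _fetch_model_names(self) -> list[str]:
        return ["fake/llama-3-8b"]
```

Returning `None` from `list_models()`, as the inline providers in the next hunks do, opts a provider out of automatic registration entirely.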
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index e238e1b78..88d7a98ec 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        return None
+
     async def unregister_model(self, model_id: str) -> None:
         pass
 
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 890c526f5..0beecd2c4 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -50,6 +50,13 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass
 
+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        # TODO: add all-mini-lm models
+        return None
+
     async def register_model(self, model: Model) -> Model:
         return model
 
diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py
index ae261f47c..ce13f0d83 100644
--- a/llama_stack/providers/remote/inference/ollama/config.py
+++ b/llama_stack/providers/remote/inference/ollama/config.py
@@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
 
 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
-    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
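
Since the refresh cadence now lives on `ModelsRoutingTable.model_refresh_interval_seconds` rather than in each provider config, the only per-provider knob left is the boolean opt-in. A hedged usage sketch of the trimmed Ollama config (values illustrative):

```python
# Usage sketch for the trimmed config; values are illustrative.
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig

config = OllamaImplConfig(
    url="http://localhost:11434",
    refresh_models=True,  # opt in; the 300s cadence is now owned by the routing table
)
```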
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 76d789d07..e9a10d0a8 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
-import asyncio
 import base64
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
@@ -121,59 +120,27 @@ class OllamaInferenceAdapter(
                 "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
             )
 
-        if self.config.refresh_models:
-            logger.debug("ollama starting background model refresh task")
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                if task.cancelled():
-                    import traceback
-
-                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    logger.error(f"ollama background refresh task died: {task.exception()}")
-                else:
-                    logger.error("ollama background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        # Wait for model store to be available (with timeout)
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models
 
+    async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        while True:
-            try:
-                response = await self.client.list()
-            except Exception as e:
-                logger.warning(f"Failed to list models: {str(e)}")
-                await asyncio.sleep(self.config.refresh_models_interval)
+        response = await self.client.list()
+        models = []
+        for m in response.models:
+            model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
+            if model_type == ModelType.embedding:
                 continue
-
-            models = []
-            for m in response.models:
-                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-                if model_type == ModelType.embedding:
-                    continue
-                models.append(
-                    Model(
-                        identifier=m.model,
-                        provider_resource_id=m.model,
-                        provider_id=provider_id,
-                        metadata={},
-                        model_type=model_type,
-                    )
+            models.append(
+                Model(
+                    identifier=m.model,
+                    provider_resource_id=m.model,
+                    provider_id=provider_id,
+                    metadata={},
+                    model_type=model_type,
                 )
-            await self.model_store.update_registered_llm_models(provider_id, models)
-            logger.debug(f"ollama refreshed model list ({len(models)} models)")
-
-            await asyncio.sleep(self.config.refresh_models_interval)
+            )
+        return models
 
     async def health(self) -> HealthResponse:
         """
@@ -190,10 +157,6 @@ class OllamaInferenceAdapter(
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
     async def shutdown(self) -> None:
-        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
-            logger.debug("ollama cancelling background refresh task")
-            self._refresh_task.cancel()
-
         self._client = None
         self._openai_client = None
 
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index ee72f974a..a5bf0e4bc 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically",
     )
-    refresh_models_interval: int = Field(
-        default=300,
-        description="Interval in seconds to refresh models",
-    )
 
     @field_validator("tls_verify")
     @classmethod
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8bdba1e88..621658a48 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
-    _refresh_task: asyncio.Task | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None
 
     async def initialize(self) -> None:
-        if not self.config.url:
-            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
-            # or available in distributions like "starter" without causing a ruckus
-            return
+        pass
 
-        if self.config.refresh_models:
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                import traceback
-
-                if task.cancelled():
-                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    # print the stack trace for the exception
-                    exc = task.exception()
-                    log.error(f"vLLM background refresh task died: {exc}")
-                    traceback.print_exception(exc)
-                else:
-                    log.error("vLLM background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        provider_id = self.__provider_id__
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models
 
+    async def list_models(self) -> list[Model] | None:
         self._lazy_initialize_client()
         assert self.client is not None  # mypy
-        while True:
-            try:
-                models = []
-                async for m in self.client.models.list():
-                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
-                    models.append(
-                        Model(
-                            identifier=m.id,
-                            provider_resource_id=m.id,
-                            provider_id=provider_id,
-                            metadata={},
-                            model_type=model_type,
-                        )
-                    )
-                await self.model_store.update_registered_llm_models(provider_id, models)
-                log.debug(f"vLLM refreshed model list ({len(models)} models)")
-            except Exception as e:
-                log.error(f"vLLM background refresh task failed: {e}")
-            await asyncio.sleep(self.config.refresh_models_interval)
+        models = []
+        async for m in self.client.models.list():
+            model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+            models.append(
+                Model(
+                    identifier=m.id,
+                    provider_resource_id=m.id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=model_type,
+                )
+            )
+        return models
 
     async def shutdown(self) -> None:
-        if self._refresh_task:
-            self._refresh_task.cancel()
-            self._refresh_task = None
+        pass
 
     async def unregister_model(self, model_id: str) -> None:
         pass
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 651d58e2a..84265a85a 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -65,6 +65,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
 
 
 class ModelRegistryHelper(ModelsProtocolPrivate):
+    __provider_id__: str
+
     def __init__(self, model_entries: list[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
@@ -79,6 +81,25 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
             self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for entry in self.model_entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for id in ids:
+                models.append(
+                    Model(
+                        identifier=id,
+                        provider_resource_id=entry.provider_model_id,
+                        model_type=ModelType.llm,
+                        metadata=entry.metadata,
+                        provider_id=self.__provider_id__,
+                    )
+                )
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        return False
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)
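
To make the `listed_providers` bookkeeping in `ModelsRoutingTable._refresh_models` concrete: every provider gets listed at least once, but only providers whose `should_refresh_models()` returns True are polled again on later cycles. The standalone simulation below mirrors that gating under those assumptions; the provider classes and string "models" are invented stand-ins, and a print replaces actual registration.

```python
# Standalone simulation of the routing-table gating above; everything here is
# an illustrative stand-in except the should_refresh_models/list_models contract.
import asyncio


class StaticProvider:
    """Fixed catalog: listed once, then skipped."""

    async def should_refresh_models(self) -> bool:
        return False

    async def list_models(self) -> list[str] | None:
        return ["static/model-a"]


class DynamicProvider:
    """Changing catalog: re-polled every cycle."""

    def __init__(self) -> None:
        self.generation = 0

    async def should_refresh_models(self) -> bool:
        return True

    async def list_models(self) -> list[str] | None:
        self.generation += 1
        return [f"dynamic/model-gen{self.generation}"]


async def refresh_once(providers: dict, listed: set[str]) -> None:
    for provider_id, provider in providers.items():
        refresh = await provider.should_refresh_models()
        if not (refresh or provider_id not in listed):
            continue  # already listed and not opted in to periodic refresh
        models = await provider.list_models()
        listed.add(provider_id)
        if models is None:
            continue
        print(provider_id, "->", models)  # stands in for registration


async def main() -> None:
    providers = {"static": StaticProvider(), "dynamic": DynamicProvider()}
    listed: set[str] = set()
    for cycle in range(1, 4):
        print(f"-- cycle {cycle} --")
        await refresh_once(providers, listed)


asyncio.run(main())
```

Running this prints both providers on the first cycle and only the dynamic one on cycles two and three, which is the intended division of labor: one-time listing for static catalogs, periodic polling only where the provider opts in.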