diff --git a/docs/source/providers/inference/remote_ollama.md b/docs/source/providers/inference/remote_ollama.md
index fcb44c072..23b8f87a2 100644
--- a/docs/source/providers/inference/remote_ollama.md
+++ b/docs/source/providers/inference/remote_ollama.md
@@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `` | No | http://localhost:11434 | |
+| `refresh_models` | `` | No | False | refresh and re-register models periodically |
+| `refresh_models_interval` | `` | No | 300 | interval in seconds to refresh models |
 
 ## Sample Configuration
 
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 222099064..26de04b68 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
+    async def update_registered_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None: ...
+
 
 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 6c51dc2c7..5dc0078d4 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -151,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data
 
+        self.loop = asyncio.new_event_loop()
+
     def initialize(self):
         if in_notebook():
             import nest_asyncio
@@ -159,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         if not self.skip_logger_removal:
             self._remove_root_logger_handlers()
 
-        return asyncio.run(self.async_client.initialize())
+        return self.loop.run_until_complete(self.async_client.initialize())
 
     def _remove_root_logger_handlers(self):
         """
@@ -172,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
 
     def request(self, *args, **kwargs):
-        # NOTE: We are using AsyncLlamaStackClient under the hood
-        # A new event loop is needed to convert the AsyncStream
-        # from async client into SyncStream return type for streaming
-        loop = asyncio.new_event_loop()
+        loop = self.loop
        asyncio.set_event_loop(loop)
 
         if kwargs.get("stream"):
@@ -192,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                     pending = asyncio.all_tasks(loop)
                     if pending:
                         loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-                    loop.close()
 
             return sync_generator()
         else:
@@ -202,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 pending = asyncio.all_tasks(loop)
                 if pending:
                     loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-                loop.close()
 
             return result
 
diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index c6a10ea9b..90f8afa1c 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -80,3 +80,34 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if existing_model is None:
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
+
+    async def update_registered_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None:
+        existing_models = await self.get_all_with_type("model")
+
+        # we may have an alias for the model registered by the user (or during initialization
+        # from run.yaml) that we need to keep track of
+        model_ids = {}
+        for model in existing_models:
+            if model.provider_id == provider_id:
+                model_ids[model.provider_resource_id] = model.identifier
+                logger.debug(f"unregistering model {model.identifier}")
+                await self.unregister_object(model)
+
+        for model in models:
+            if model.provider_resource_id in model_ids:
+                model.identifier = model_ids[model.provider_resource_id]
+
+            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
+            await self.register_object(
+                ModelWithOwner(
+                    identifier=model.identifier,
+                    provider_resource_id=model.provider_resource_id,
+                    provider_id=provider_id,
+                    metadata=model.metadata,
+                    model_type=model.model_type,
+                )
+            )
diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py
index 0145810a8..ae261f47c 100644
--- a/llama_stack/providers/remote/inference/ollama/config.py
+++ b/llama_stack/providers/remote/inference/ollama/config.py
@@ -6,13 +6,15 @@
 
 from typing import Any
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 DEFAULT_OLLAMA_URL = "http://localhost:11434"
 
 
 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
+    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
+    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 010e346bd..a1f7743d5 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
+import asyncio
 import base64
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 
@@ -91,23 +92,88 @@ class OllamaInferenceAdapter(
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
-        self.url = config.url
+        self.config = config
+        self._client = None
+        self._openai_client = None
 
     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.url)
+        if self._client is None:
+            self._client = AsyncClient(host=self.config.url)
+        return self._client
 
     @property
     def openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
+        if self._openai_client is None:
+            self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
+        return self._openai_client
 
     async def initialize(self) -> None:
-        logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
+        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
         health_response = await self.health()
         if health_response["status"] == HealthStatus.ERROR:
-            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            logger.warning(
+                "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
+            )
+
+        if self.config.refresh_models:
+            logger.debug("ollama starting background model refresh task")
+            self._refresh_task = asyncio.create_task(self._refresh_models())
+
+            def cb(task):
+                if task.cancelled():
+                    import traceback
+
+                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
+                elif task.exception():
+                    logger.error(f"ollama background refresh task died: {task.exception()}")
+                else:
+                    logger.error("ollama background refresh task completed unexpectedly")
+
+            self._refresh_task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        # Wait for model store to be available (with timeout)
+        waited_time = 0
+        while not self.model_store and waited_time < 60:
+            await asyncio.sleep(1)
+            waited_time += 1
+
+        if not self.model_store:
+            raise ValueError("Model store not set after waiting 60 seconds")
+
+        provider_id = self.__provider_id__
+        while True:
+            try:
+                response = await self.client.list()
+            except Exception as e:
+                logger.warning(f"Failed to list models: {str(e)}")
+                await asyncio.sleep(self.config.refresh_models_interval)
+                continue
+
+            models = []
+            for m in response.models:
+                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
+                # unfortunately, ollama does not provide embedding dimension in the model list :(
+                # we should likely add a hard-coded mapping of model name to embedding dimension
+                models.append(
+                    Model(
+                        identifier=m.model,
+                        provider_resource_id=m.model,
+                        provider_id=provider_id,
+                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
+                        model_type=model_type,
+                    )
+                )
+            await self.model_store.update_registered_models(provider_id, models)
+            logger.debug(f"ollama refreshed model list ({len(models)} models)")
+
+            await asyncio.sleep(self.config.refresh_models_interval)
 
     async def health(self) -> HealthResponse:
         """
@@ -124,7 +190,12 @@ class OllamaInferenceAdapter(
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
     async def shutdown(self) -> None:
-        pass
+        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
+            logger.debug("ollama cancelling background refresh task")
+            self._refresh_task.cancel()
+
+        self._client = None
+        self._openai_client = None
 
     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -354,8 +425,6 @@ class OllamaInferenceAdapter(
             raise ValueError("Model provider_resource_id cannot be None")
 
         if model.model_type == ModelType.embedding:
-            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
-            # TODO: you should pull here only if the model is not found in a list
             response = await self.client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
                 await self.client.pull(model.provider_resource_id)
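
---
Notes (illustrative sketches, not part of the patch):

Enabling the new behavior only requires the two config fields added above. A minimal sketch, assuming the import path from the patch; the interval value here is arbitrary:

    from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig

    # Enable periodic model refresh on the Ollama provider; initialize()
    # will then spawn the background _refresh_models() task.
    config = OllamaImplConfig(
        url="http://localhost:11434",  # DEFAULT_OLLAMA_URL
        refresh_models=True,
        refresh_models_interval=120,   # seconds between model list polls (default 300)
    )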
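update_registered_models in ModelsRoutingTable preserves user-chosen aliases across refreshes: if a model was registered under an alias (from run.yaml or manual registration), the refreshed entry keeps that identifier. A standalone sketch of just that mapping rule, with illustrative model names:

    # existing maps provider_resource_id -> registered identifier (alias);
    # refreshed models keep their alias, unknown ones get an identity mapping.
    def preserve_aliases(existing: dict[str, str], refreshed: list[str]) -> dict[str, str]:
        return {rid: existing.get(rid, rid) for rid in refreshed}

    assert preserve_aliases(
        {"llama3.2:3b": "my-llama"},           # user alias registered earlier
        ["llama3.2:3b", "all-minilm:latest"],  # models reported by the provider
    ) == {"llama3.2:3b": "my-llama", "all-minilm:latest": "all-minilm:latest"}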
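The initialize()/shutdown() changes follow the usual asyncio background-task lifecycle: create the task, attach a done-callback so a task that dies silently is at least logged, and cancel it on shutdown. A reduced, runnable sketch of that pattern, independent of the Ollama specifics:

    import asyncio

    async def refresh_forever(interval: float) -> None:
        while True:
            print("refreshing model list...")  # stand-in for client.list() + re-registration
            await asyncio.sleep(interval)

    async def main() -> None:
        task = asyncio.create_task(refresh_forever(0.1))
        # log unexpected exits; check t.cancelled() first, since t.exception()
        # raises CancelledError on a cancelled task
        task.add_done_callback(
            lambda t: print("refresh task exited:", "cancelled" if t.cancelled() else t.exception())
        )
        await asyncio.sleep(0.35)  # ... serve requests ...
        if not task.done():        # mirrors the patched shutdown()
            task.cancel()
        await asyncio.gather(task, return_exceptions=True)

    asyncio.run(main())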