feat(ollama): periodically refresh models (#2805)

For self-hosted providers like Ollama (or vLLM), the backing server is
running a set of models. That server should be treated as the source of
truth and the Stack registry should just be a cache for those models. Of
course, in production environments, you may not want this (because you
know what model you are running statically) hence there's a config
boolean to control this behavior.

_This is part of a series of PRs aimed at removing the requirement of
needing to set `INFERENCE_MODEL` env variables for running Llama Stack
server._

## Test Plan

Copy and modify the starter.yaml template / config and enable
`refresh_models: true, refresh_models_interval: 10` for the ollama
provider. Then, run:

```
LLAMA_STACK_LOGGING=all=debug \
  ENABLE_OLLAMA=ollama uv run llama stack run --image-type venv /tmp/starter.yaml
```

See a gargantuan amount of logs, but verify that the provider is
periodically refreshing models. Stop and prune a model from ollama
server, restart the server. Verify that the model goes away when I call
`uv run llama-stack-client models list`
This commit is contained in:
Ashwin Bharambe 2025-07-18 12:20:36 -07:00 committed by GitHub
parent 6d55f2f137
commit 68a2dfbad7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 123 additions and 16 deletions

View file

@ -6,13 +6,15 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
@ -91,23 +92,88 @@ class OllamaInferenceAdapter(
InferenceProvider,
ModelsProtocolPrivate,
):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.url = config.url
self.config = config
self._client = None
self._openai_client = None
@property
def client(self) -> AsyncClient:
return AsyncClient(host=self.url)
if self._client is None:
self._client = AsyncClient(host=self.config.url)
return self._client
@property
def openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
if self._openai_client is None:
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
return self._openai_client
async def initialize(self) -> None:
logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR:
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
logger.warning(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)
if self.config.refresh_models:
logger.debug("ollama starting background model refresh task")
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
provider_id = self.__provider_id__
while True:
try:
response = await self.client.list()
except Exception as e:
logger.warning(f"Failed to list models: {str(e)}")
await asyncio.sleep(self.config.refresh_models_interval)
continue
models = []
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
# unfortunately, ollama does not provide embedding dimension in the model list :(
# we should likely add a hard-coded mapping of model name to embedding dimension
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
model_type=model_type,
)
)
await self.model_store.update_registered_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval)
async def health(self) -> HealthResponse:
"""
@ -124,7 +190,12 @@ class OllamaInferenceAdapter(
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None:
pass
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None
self._openai_client = None
async def unregister_model(self, model_id: str) -> None:
pass
@ -354,8 +425,6 @@ class OllamaInferenceAdapter(
raise ValueError("Model provider_resource_id cannot be None")
if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
# TODO: you should pull here only if the model is not found in a list
response = await self.client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)