feat(ollama): periodically refresh models (#2805)

For self-hosted providers like Ollama (or vLLM), the backing server is
running a set of models. That server should be treated as the source of
truth and the Stack registry should just be a cache for those models. Of
course, in production environments you may not want this (because you
know statically which models you are running), so there is a config
boolean to control this behavior.
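
For example, a run config might enable it like this (a sketch only; the provider entry layout follows the starter template and the interval value is just illustrative):

```
providers:
  inference:
  - provider_id: ollama
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:=http://localhost:11434}
      refresh_models: true              # treat the Ollama server as the source of truth
      refresh_models_interval: 60       # seconds between refreshes (defaults to 300)
```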

_This is part of a series of PRs aimed at removing the requirement of
needing to set `INFERENCE_MODEL` env variables for running Llama Stack
server._

## Test Plan

Copy and modify the starter.yaml template / config and enable
`refresh_models: true, refresh_models_interval: 10` for the ollama
provider. Then, run:

```
LLAMA_STACK_LOGGING=all=debug \
  ENABLE_OLLAMA=ollama uv run llama stack run --image-type venv /tmp/starter.yaml
```

See a gargantuan amount of logs, but verify that the provider is
periodically refreshing models. Stop and prune a model from the Ollama
server, then restart the server. Verify that the model goes away when you run
`uv run llama-stack-client models list`.
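
Concretely, the prune-and-verify step can be driven with something like the following (the model name is only an example):

```
# see what the Stack registry currently reports
uv run llama-stack-client models list

# remove a model from the Ollama server (example model name)
ollama rm llama3.2:3b

# after the next refresh interval, the model should no longer be listed
uv run llama-stack-client models list
```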

Ashwin Bharambe, 2025-07-18 12:20:36 -07:00 (committed via GitHub)
commit 68a2dfbad7, parent 6d55f2f137
6 changed files with 123 additions and 16 deletions

@@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `<class 'str'>` | No | http://localhost:11434 | |
+| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
+| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |

 ## Sample Configuration

@@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

+    async def update_registered_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None: ...
+

 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

@@ -151,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data

+        self.loop = asyncio.new_event_loop()
+
     def initialize(self):
         if in_notebook():
             import nest_asyncio
@@ -159,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         if not self.skip_logger_removal:
             self._remove_root_logger_handlers()

-        return asyncio.run(self.async_client.initialize())
+        return self.loop.run_until_complete(self.async_client.initialize())

     def _remove_root_logger_handlers(self):
         """
@@ -172,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

     def request(self, *args, **kwargs):
-        # NOTE: We are using AsyncLlamaStackClient under the hood
-        # A new event loop is needed to convert the AsyncStream
-        # from async client into SyncStream return type for streaming
-        loop = asyncio.new_event_loop()
+        loop = self.loop
         asyncio.set_event_loop(loop)

         if kwargs.get("stream"):
@@ -192,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                     pending = asyncio.all_tasks(loop)
                     if pending:
                         loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-                    loop.close()

             return sync_generator()
         else:
@@ -202,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 pending = asyncio.all_tasks(loop)
                 if pending:
                     loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-                loop.close()
             return result

@@ -80,3 +80,34 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if existing_model is None:
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
+
+    async def update_registered_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None:
+        existing_models = await self.get_all_with_type("model")
+
+        # we may have an alias for the model registered by the user (or during initialization
+        # from run.yaml) that we need to keep track of
+        model_ids = {}
+        for model in existing_models:
+            if model.provider_id == provider_id:
+                model_ids[model.provider_resource_id] = model.identifier
+                logger.debug(f"unregistering model {model.identifier}")
+                await self.unregister_object(model)
+
+        for model in models:
+            if model.provider_resource_id in model_ids:
+                model.identifier = model_ids[model.provider_resource_id]
+
+            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
+            await self.register_object(
+                ModelWithOwner(
+                    identifier=model.identifier,
+                    provider_resource_id=model.provider_resource_id,
+                    provider_id=provider_id,
+                    metadata=model.metadata,
+                    model_type=model.model_type,
+                )
+            )

@@ -6,13 +6,15 @@

 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
+    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
+    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

+import asyncio
 import base64
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
@@ -91,23 +92,88 @@ class OllamaInferenceAdapter(
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
-        self.url = config.url
+        self.config = config
+        self._client = None
+        self._openai_client = None

     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.url)
+        if self._client is None:
+            self._client = AsyncClient(host=self.config.url)
+        return self._client

     @property
     def openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
+        if self._openai_client is None:
+            self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
+        return self._openai_client

     async def initialize(self) -> None:
-        logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
+        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
         health_response = await self.health()
         if health_response["status"] == HealthStatus.ERROR:
-            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            logger.warning(
+                "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
+            )
+
+        if self.config.refresh_models:
+            logger.debug("ollama starting background model refresh task")
+            self._refresh_task = asyncio.create_task(self._refresh_models())
+
+            def cb(task):
+                if task.cancelled():
+                    import traceback
+
+                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
+                elif task.exception():
+                    logger.error(f"ollama background refresh task died: {task.exception()}")
+                else:
+                    logger.error("ollama background refresh task completed unexpectedly")
+
+            self._refresh_task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        # Wait for model store to be available (with timeout)
+        waited_time = 0
+        while not self.model_store and waited_time < 60:
+            await asyncio.sleep(1)
+            waited_time += 1
+
+        if not self.model_store:
+            raise ValueError("Model store not set after waiting 60 seconds")
+
+        provider_id = self.__provider_id__
+        while True:
+            try:
+                response = await self.client.list()
+            except Exception as e:
+                logger.warning(f"Failed to list models: {str(e)}")
+                await asyncio.sleep(self.config.refresh_models_interval)
+                continue
+
+            models = []
+            for m in response.models:
+                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
+                # unfortunately, ollama does not provide embedding dimension in the model list :(
+                # we should likely add a hard-coded mapping of model name to embedding dimension
+                models.append(
+                    Model(
+                        identifier=m.model,
+                        provider_resource_id=m.model,
+                        provider_id=provider_id,
+                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
+                        model_type=model_type,
+                    )
+                )
+            await self.model_store.update_registered_models(provider_id, models)
+            logger.debug(f"ollama refreshed model list ({len(models)} models)")
+
+            await asyncio.sleep(self.config.refresh_models_interval)

     async def health(self) -> HealthResponse:
         """
@@ -124,7 +190,12 @@ class OllamaInferenceAdapter(
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

     async def shutdown(self) -> None:
-        pass
+        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
+            logger.debug("ollama cancelling background refresh task")
+            self._refresh_task.cancel()
+
+        self._client = None
+        self._openai_client = None

     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -354,8 +425,6 @@ class OllamaInferenceAdapter(
             raise ValueError("Model provider_resource_id cannot be None")

         if model.model_type == ModelType.embedding:
-            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
-            # TODO: you should pull here only if the model is not found in a list
             response = await self.client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
                 await self.client.pull(model.provider_resource_id)