Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-20 11:47:00 +00:00)
feat(ollama): periodically refresh models (#2805)
For self-hosted providers like Ollama (or vLLM), the backing server runs a set of models. That server should be treated as the source of truth, and the Stack registry should just be a cache for those models. Of course, in production environments you may not want this (because you know statically which model you are running), so there is a config boolean to control the behavior.

_This is part of a series of PRs aimed at removing the requirement of setting `INFERENCE_MODEL` env variables for running Llama Stack server._

## Test Plan

Copy and modify the starter.yaml template / config and enable `refresh_models: true, refresh_models_interval: 10` for the ollama provider. Then, run:

```
LLAMA_STACK_LOGGING=all=debug \
ENABLE_OLLAMA=ollama uv run llama stack run --image-type venv /tmp/starter.yaml
```

You will see a gargantuan amount of logs, but verify that the provider is periodically refreshing models. Stop and prune a model from the ollama server, then restart the server. Verify that the model goes away when you run `uv run llama-stack-client models list`.
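For context, here is a minimal sketch of the two new knobs exercised directly on the provider config object instead of via starter.yaml. `OllamaImplConfig` and its fields come from the diff below; the import path is an assumption based on the usual provider layout.

```
# Sketch only: the import path is assumed, not confirmed by this PR.
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig

config = OllamaImplConfig(
    url="http://localhost:11434",
    refresh_models=True,           # treat the Ollama server as the source of truth
    refresh_models_interval=10,    # poll every 10 seconds, as in the test plan
)
print(config.model_dump())
```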
Parent: 6d55f2f137
Commit: 68a2dfbad7
6 changed files with 123 additions and 16 deletions
@@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `<class 'str'>` | No | http://localhost:11434 | |
+| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
+| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |

 ## Sample Configuration
@@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

+    async def update_registered_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None: ...
+

 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
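The new `update_registered_models` hook is what lets a provider push its current model list into the Stack registry. A minimal, hypothetical in-memory store (not part of this PR; the `Model` fields are taken from the adapter code below, and the import path is assumed) sketches the expected semantics: each refresh replaces everything previously registered under that `provider_id`.

```
from llama_stack.apis.models import Model  # import path assumed

class InMemoryModelStore:
    """Illustrative ModelStore: swaps out a provider's model list wholesale on refresh."""

    def __init__(self) -> None:
        self._models: dict[str, Model] = {}

    async def get_model(self, identifier: str) -> Model:
        return self._models[identifier]

    async def update_registered_models(self, provider_id: str, models: list[Model]) -> None:
        # drop models previously registered for this provider, then re-register the new list
        self._models = {k: v for k, v in self._models.items() if v.provider_id != provider_id}
        for model in models:
            self._models[model.identifier] = model
```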
@@ -151,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data

+        self.loop = asyncio.new_event_loop()
+
     def initialize(self):
         if in_notebook():
             import nest_asyncio
@@ -159,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             if not self.skip_logger_removal:
                 self._remove_root_logger_handlers()

-        return asyncio.run(self.async_client.initialize())
+        return self.loop.run_until_complete(self.async_client.initialize())

     def _remove_root_logger_handlers(self):
         """
@@ -172,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

     def request(self, *args, **kwargs):
-        # NOTE: We are using AsyncLlamaStackClient under the hood
-        # A new event loop is needed to convert the AsyncStream
-        # from async client into SyncStream return type for streaming
-        loop = asyncio.new_event_loop()
+        loop = self.loop
         asyncio.set_event_loop(loop)

         if kwargs.get("stream"):
@@ -192,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                     pending = asyncio.all_tasks(loop)
                     if pending:
                         loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-                    loop.close()

             return sync_generator()
         else:
@@ -202,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             pending = asyncio.all_tasks(loop)
             if pending:
                 loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-            loop.close()
             return result
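These hunks replace the event loop that used to be created (and closed) on every `request()` call with a single loop owned by the client, so work scheduled on that loop, such as a provider's background refresh task, keeps running across calls. A minimal sketch of the pattern; the `SyncFacade` class is illustrative only and not part of the PR.

```
import asyncio

class SyncFacade:
    """Illustrative only: one loop created up front and reused for every sync call."""

    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()

    def call(self, coro):
        asyncio.set_event_loop(self.loop)
        return self.loop.run_until_complete(coro)

async def double(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2

facade = SyncFacade()
print(facade.call(double(21)))  # 42
print(facade.call(double(4)))   # 8 -- the same loop is reused, not a fresh one per call
```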
@@ -80,3 +80,34 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if existing_model is None:
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
+
+    async def update_registered_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None:
+        existing_models = await self.get_all_with_type("model")
+
+        # we may have an alias for the model registered by the user (or during initialization
+        # from run.yaml) that we need to keep track of
+        model_ids = {}
+        for model in existing_models:
+            if model.provider_id == provider_id:
+                model_ids[model.provider_resource_id] = model.identifier
+                logger.debug(f"unregistering model {model.identifier}")
+                await self.unregister_object(model)
+
+        for model in models:
+            if model.provider_resource_id in model_ids:
+                model.identifier = model_ids[model.provider_resource_id]
+
+            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
+            await self.register_object(
+                ModelWithOwner(
+                    identifier=model.identifier,
+                    provider_resource_id=model.provider_resource_id,
+                    provider_id=provider_id,
+                    metadata=model.metadata,
+                    model_type=model.model_type,
+                )
+            )
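The alias handling above means a user-chosen identifier survives a refresh as long as the provider still serves the underlying model. A small, self-contained sketch of that mapping logic, with plain dicts standing in for the registry; none of these names come from the PR.

```
# previously registered entries for this provider: identifier -> provider_resource_id
existing = {"my-llama": "llama3.2:3b", "all-minilm": "all-minilm:latest"}

# what the provider reports after a refresh (one model was pruned from the server)
refreshed = ["llama3.2:3b"]

# keep the user-facing alias for any resource id that is still being served
aliases = {resource_id: identifier for identifier, resource_id in existing.items()}
new_registry = {aliases.get(rid, rid): rid for rid in refreshed}

print(new_registry)  # {'my-llama': 'llama3.2:3b'} -- alias kept, pruned model dropped
```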
@@ -6,13 +6,15 @@
 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
+    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
+    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

+import asyncio
 import base64
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
@@ -91,23 +92,88 @@ class OllamaInferenceAdapter(
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
-        self.url = config.url
+        self.config = config
+        self._client = None
+        self._openai_client = None

     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.url)
+        if self._client is None:
+            self._client = AsyncClient(host=self.config.url)
+        return self._client

     @property
     def openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
+        if self._openai_client is None:
+            self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
+        return self._openai_client

     async def initialize(self) -> None:
-        logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
+        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
         health_response = await self.health()
         if health_response["status"] == HealthStatus.ERROR:
-            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            logger.warning(
+                "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
+            )
+
+        if self.config.refresh_models:
+            logger.debug("ollama starting background model refresh task")
+            self._refresh_task = asyncio.create_task(self._refresh_models())
+
+            def cb(task):
+                if task.cancelled():
+                    import traceback
+
+                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
+                elif task.exception():
+                    logger.error(f"ollama background refresh task died: {task.exception()}")
+                else:
+                    logger.error("ollama background refresh task completed unexpectedly")
+
+            self._refresh_task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        # Wait for model store to be available (with timeout)
+        waited_time = 0
+        while not self.model_store and waited_time < 60:
+            await asyncio.sleep(1)
+            waited_time += 1
+
+        if not self.model_store:
+            raise ValueError("Model store not set after waiting 60 seconds")
+
+        provider_id = self.__provider_id__
+        while True:
+            try:
+                response = await self.client.list()
+            except Exception as e:
+                logger.warning(f"Failed to list models: {str(e)}")
+                await asyncio.sleep(self.config.refresh_models_interval)
+                continue
+
+            models = []
+            for m in response.models:
+                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
+                # unfortunately, ollama does not provide embedding dimension in the model list :(
+                # we should likely add a hard-coded mapping of model name to embedding dimension
+                models.append(
+                    Model(
+                        identifier=m.model,
+                        provider_resource_id=m.model,
+                        provider_id=provider_id,
+                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
+                        model_type=model_type,
+                    )
+                )
+            await self.model_store.update_registered_models(provider_id, models)
+            logger.debug(f"ollama refreshed model list ({len(models)} models)")
+
+            await asyncio.sleep(self.config.refresh_models_interval)

     async def health(self) -> HealthResponse:
         """
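The refresh task is fire-and-forget, so the adapter attaches a done callback to make a silent failure visible in the logs. A tiny, standalone sketch of that pattern; the names here are illustrative, not taken from the adapter.

```
import asyncio

async def flaky_refresh() -> None:
    await asyncio.sleep(0.1)
    raise RuntimeError("simulated refresh failure")

def report(task: asyncio.Task) -> None:
    # surfaces cancellation or an unhandled exception from a background task
    if task.cancelled():
        print("refresh task cancelled")
    elif task.exception():
        print(f"refresh task died: {task.exception()}")

async def main() -> None:
    task = asyncio.create_task(flaky_refresh())
    task.add_done_callback(report)
    await asyncio.sleep(0.5)  # the callback fires when the task raises

asyncio.run(main())
```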
@@ -124,7 +190,12 @@ class OllamaInferenceAdapter(
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

     async def shutdown(self) -> None:
-        pass
+        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
+            logger.debug("ollama cancelling background refresh task")
+            self._refresh_task.cancel()
+
+        self._client = None
+        self._openai_client = None

     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -354,8 +425,6 @@ class OllamaInferenceAdapter(
             raise ValueError("Model provider_resource_id cannot be None")

         if model.model_type == ModelType.embedding:
-            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
-            # TODO: you should pull here only if the model is not found in a list
             response = await self.client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
                 await self.client.pull(model.provider_resource_id)