feat(vllm): periodically refresh models (#2823)

Just like #2805 but for vLLM.

We also make VLLM_URL env variable optional (not required) -- if not
specified, the provider silently sits idle and yells eventually if
someone tries to call a completion on it. This is done so as to allow
this provider to be present in the `starter` distribution.

## Test Plan

Set up vLLM, copy the starter template and set `{ refresh_models: true,
refresh_models_interval: 10 }` for the vllm provider and then run:

```
ENABLE_VLLM=vllm VLLM_URL=http://localhost:8000/v1 \
  uv run llama stack run --image-type venv /tmp/starter.yaml
```

Verify that `llama-stack-client models list` brings up the model
correctly from vLLM.
This commit is contained in:
Ashwin Bharambe 2025-07-18 15:53:09 -07:00 committed by GitHub
parent ade075152e
commit 199f859eec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 98 additions and 14 deletions

View file

@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
## Sample Configuration
```yaml
url: ${env.VLLM_URL}
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}

View file

@ -819,7 +819,7 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
async def update_registered_models(
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],

View file

@ -81,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
async def update_registered_models(
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
@ -92,12 +92,16 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
if model.provider_id == provider_id:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]

View file

@ -159,18 +159,18 @@ class OllamaInferenceAdapter(
models = []
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
# unfortunately, ollama does not provide embedding dimension in the model list :(
# we should likely add a hard-coded mapping of model name to embedding dimension
if model_type == ModelType.embedding:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_models(provider_id, models)
await self.model_store.update_registered_llm_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval)

View file

@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify")
@classmethod
@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
url: str = "${env.VLLM_URL}",
url: str = "${env.VLLM_URL:=}",
**kwargs,
):
return {

View file

@ -3,8 +3,8 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ModelStore,
OpenAIChatCompletion,
OpenAICompletion,
OpenAIEmbeddingData,
@ -54,6 +55,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import (
@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
def build_hf_repo_model_entries():
@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
self.client = None
async def initialize(self) -> None:
pass
if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
async def shutdown(self) -> None:
pass
if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
async def unregister_model(self, model_id: str) -> None:
pass
@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status.
"""
try:
if not self.config.url:
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if self.client is not None:
return
if not self.config.url:
raise ValueError(
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
)
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()

View file

@ -26,7 +26,7 @@ providers:
- provider_id: ${env.ENABLE_VLLM:=__disabled__}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL}
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}