feat(vllm): periodically refresh models (#2823)

Just like #2805 but for vLLM.

We also make the `VLLM_URL` environment variable optional (not required) -- if it is not
specified, the provider silently sits idle and only raises an error once
someone tries to call a completion on it. This is done so that the provider
can be included in the `starter` distribution.
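In other words, the adapter skips client construction when no URL is configured and only complains at the point of use. A minimal sketch of that pattern (illustrative names, not the adapter's actual API):

```python
# Sketch of the "dormant until used" behavior, assuming an optional URL and a
# lazily constructed client; names here are illustrative only.
class DormantProviderSketch:
    def __init__(self, url: str | None) -> None:
        self.url = url
        self.client: object | None = None

    async def initialize(self) -> None:
        # No URL configured: stay dormant instead of failing at startup, so the
        # provider can be listed in the `starter` distribution unconfigured.
        if not self.url:
            return
        self.client = self._create_client()

    def _create_client(self) -> object:
        # Stand-in for building the real OpenAI-compatible client.
        return object()

    def _ensure_client(self) -> object:
        # Only raise when someone actually tries to call a completion.
        if not self.url:
            raise ValueError("You must set VLLM_URL to use this provider")
        if self.client is None:
            self.client = self._create_client()
        return self.client
```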

## Test Plan

Set up vLLM, copy the starter template, set `{ refresh_models: true,
refresh_models_interval: 10 }` on the vllm provider (see the sketch below the run command), and then run:

```
ENABLE_VLLM=vllm VLLM_URL=http://localhost:8000/v1 \
  uv run llama stack run --image-type venv /tmp/starter.yaml
```
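
For reference, the resulting vllm provider entry in `/tmp/starter.yaml` looks roughly like this (a sketch based on the starter template; only the two `refresh_models*` keys are additions for this test):

```yaml
inference:
- provider_id: ${env.ENABLE_VLLM:=__disabled__}
  provider_type: remote::vllm
  config:
    url: ${env.VLLM_URL:=}
    max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
    api_token: ${env.VLLM_API_TOKEN:=fake}
    tls_verify: ${env.VLLM_TLS_VERIFY:=true}
    refresh_models: true
    refresh_models_interval: 10
```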

Verify that `llama-stack-client models list` brings up the model
correctly from vLLM.
Commit: 199f859eec (parent: ade075152e)
Author: Ashwin Bharambe (committed via GitHub)
Date: 2025-07-18 15:53:09 -07:00
7 changed files with 98 additions and 14 deletions


````diff
@@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
+| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
+| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
 
 ## Sample Configuration
 
 ```yaml
-url: ${env.VLLM_URL}
+url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
````


```diff
@@ -819,7 +819,7 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
-    async def update_registered_models(
+    async def update_registered_llm_models(
         self,
         provider_id: str,
         models: list[Model],
```


```diff
@@ -81,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
 
-    async def update_registered_models(
+    async def update_registered_llm_models(
         self,
         provider_id: str,
         models: list[Model],
@@ -92,12 +92,16 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         # from run.yaml) that we need to keep track of
         model_ids = {}
         for model in existing_models:
-            if model.provider_id == provider_id:
+            # we leave embeddings models alone because often we don't get metadata
+            # (embedding dimension, etc.) from the provider
+            if model.provider_id == provider_id and model.model_type == ModelType.llm:
                 model_ids[model.provider_resource_id] = model.identifier
                 logger.debug(f"unregistering model {model.identifier}")
                 await self.unregister_object(model)
 
         for model in models:
+            if model.model_type != ModelType.llm:
+                continue
             if model.provider_resource_id in model_ids:
                 model.identifier = model_ids[model.provider_resource_id]
```


```diff
@@ -159,18 +159,18 @@ class OllamaInferenceAdapter(
             models = []
             for m in response.models:
                 model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-                # unfortunately, ollama does not provide embedding dimension in the model list :(
-                # we should likely add a hard-coded mapping of model name to embedding dimension
+                if model_type == ModelType.embedding:
+                    continue
                 models.append(
                     Model(
                         identifier=m.model,
                         provider_resource_id=m.model,
                         provider_id=provider_id,
-                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
+                        metadata={},
                         model_type=model_type,
                     )
                 )
-            await self.model_store.update_registered_models(provider_id, models)
+            await self.model_store.update_registered_llm_models(provider_id, models)
             logger.debug(f"ollama refreshed model list ({len(models)} models)")
 
             await asyncio.sleep(self.config.refresh_models_interval)
```


```diff
@@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=True,
         description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
     )
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
+    refresh_models_interval: int = Field(
+        default=300,
+        description="Interval in seconds to refresh models",
+    )
 
     @field_validator("tls_verify")
     @classmethod
@@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.VLLM_URL}",
+        url: str = "${env.VLLM_URL:=}",
         **kwargs,
     ):
         return {
```


```diff
@@ -3,8 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
 import json
-import logging
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    ModelStore,
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAIEmbeddingData,
@@ -54,6 +55,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import (
@@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMInferenceAdapterConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
 
 
 def build_hf_repo_model_entries():
@@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(
 
 
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+    model_store: ModelStore | None = None
+    _refresh_task: asyncio.Task | None = None
+
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
         self.client = None
 
     async def initialize(self) -> None:
-        pass
+        if not self.config.url:
+            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
+            # or available in distributions like "starter" without causing a ruckus
+            return
+
+        if self.config.refresh_models:
+            self._refresh_task = asyncio.create_task(self._refresh_models())
+
+            def cb(task):
+                import traceback
+
+                if task.cancelled():
+                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
+                elif task.exception():
+                    # print the stack trace for the exception
+                    exc = task.exception()
+                    log.error(f"vLLM background refresh task died: {exc}")
+                    traceback.print_exception(exc)
+                else:
+                    log.error("vLLM background refresh task completed unexpectedly")
+
+            self._refresh_task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        provider_id = self.__provider_id__
+        waited_time = 0
+        while not self.model_store and waited_time < 60:
+            await asyncio.sleep(1)
+            waited_time += 1
+
+        if not self.model_store:
+            raise ValueError("Model store not set after waiting 60 seconds")
+
+        self._lazy_initialize_client()
+        assert self.client is not None  # mypy
+        while True:
+            try:
+                models = []
+                async for m in self.client.models.list():
+                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+                    models.append(
+                        Model(
+                            identifier=m.id,
+                            provider_resource_id=m.id,
+                            provider_id=provider_id,
+                            metadata={},
+                            model_type=model_type,
+                        )
+                    )
+                await self.model_store.update_registered_llm_models(provider_id, models)
+                log.debug(f"vLLM refreshed model list ({len(models)} models)")
+            except Exception as e:
+                log.error(f"vLLM background refresh task failed: {e}")
+            await asyncio.sleep(self.config.refresh_models_interval)
 
     async def shutdown(self) -> None:
-        pass
+        if self._refresh_task:
+            self._refresh_task.cancel()
+            self._refresh_task = None
 
     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             HealthResponse: A dictionary containing the health status.
         """
         try:
+            if not self.config.url:
+                return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
+
             client = self._create_client() if self.client is None else self.client
             _ = [m async for m in client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
@@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if self.client is not None:
             return
 
+        if not self.config.url:
+            raise ValueError(
+                "You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
+            )
+
         log.info(f"Initializing vLLM client with base_url={self.config.url}")
         self.client = self._create_client()
```
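
The refresh loop above is started as a fire-and-forget asyncio task; the `add_done_callback` hook is there because exceptions raised inside such tasks are otherwise swallowed silently. A standalone sketch of that pattern (function names and the interval are illustrative, not taken from the diff):

```python
import asyncio
import logging

log = logging.getLogger(__name__)


async def refresh_loop(interval: float) -> None:
    # Periodically refresh; keep running even if a single iteration fails.
    while True:
        try:
            ...  # fetch the model list and update the registry here
            log.debug("refreshed model list")
        except Exception as e:
            log.error(f"refresh failed: {e}")
        await asyncio.sleep(interval)


def start_refresh_task(interval: float = 10.0) -> asyncio.Task:
    # Must be called from within a running event loop.
    task = asyncio.create_task(refresh_loop(interval))

    def _on_done(t: asyncio.Task) -> None:
        # Surface the three ways a "forever" task can end: cancellation,
        # an unhandled exception, or an unexpected clean return.
        if t.cancelled():
            log.error("refresh task canceled")
        elif t.exception():
            log.error(f"refresh task died: {t.exception()}")
        else:
            log.error("refresh task completed unexpectedly")

    task.add_done_callback(_on_done)
    return task
```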


```diff
@@ -26,7 +26,7 @@ providers:
   - provider_id: ${env.ENABLE_VLLM:=__disabled__}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL}
+      url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
```