diff --git a/docs/source/providers/inference/remote_vllm.md b/docs/source/providers/inference/remote_vllm.md
index 6c725fb41..5291199a4 100644
--- a/docs/source/providers/inference/remote_vllm.md
+++ b/docs/source/providers/inference/remote_vllm.md
@@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically |
+| `refresh_models_interval` | `` | No | 300 | Interval in seconds to refresh models |
 
 ## Sample Configuration
 
 ```yaml
-url: ${env.VLLM_URL}
+url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 26de04b68..b2bb8a8e6 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -819,7 +819,7 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
-    async def update_registered_models(
+    async def update_registered_llm_models(
         self,
         provider_id: str,
         models: list[Model],
diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index 90f8afa1c..9a9db7257 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -81,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
 
-    async def update_registered_models(
+    async def update_registered_llm_models(
         self,
         provider_id: str,
         models: list[Model],
@@ -92,12 +92,16 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         # from run.yaml) that we need to keep track of
         model_ids = {}
         for model in existing_models:
-            if model.provider_id == provider_id:
+            # we leave embeddings models alone because often we don't get metadata
+            # (embedding dimension, etc.) from the provider
+            if model.provider_id == provider_id and model.model_type == ModelType.llm:
                 model_ids[model.provider_resource_id] = model.identifier
                 logger.debug(f"unregistering model {model.identifier}")
                 await self.unregister_object(model)
 
         for model in models:
+            if model.model_type != ModelType.llm:
+                continue
             if model.provider_resource_id in model_ids:
                 model.identifier = model_ids[model.provider_resource_id]
 
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index a1f7743d5..76d789d07 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -159,18 +159,18 @@ class OllamaInferenceAdapter(
             models = []
             for m in response.models:
                 model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-                # unfortunately, ollama does not provide embedding dimension in the model list :(
-                # we should likely add a hard-coded mapping of model name to embedding dimension
+                if model_type == ModelType.embedding:
+                    continue
                 models.append(
                     Model(
                         identifier=m.model,
                         provider_resource_id=m.model,
                         provider_id=provider_id,
-                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
+                        metadata={},
                         model_type=model_type,
                     )
                 )
-            await self.model_store.update_registered_models(provider_id, models)
+            await self.model_store.update_registered_llm_models(provider_id, models)
             logger.debug(f"ollama refreshed model list ({len(models)} models)")
 
             await asyncio.sleep(self.config.refresh_models_interval)
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index e11efa7f0..ee72f974a 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=True,
         description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
     )
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
+    refresh_models_interval: int = Field(
+        default=300,
+        description="Interval in seconds to refresh models",
+    )
 
     @field_validator("tls_verify")
     @classmethod
@@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.VLLM_URL}",
+        url: str = "${env.VLLM_URL:=}",
         **kwargs,
     ):
         return {
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d1455acaa..8bdba1e88 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,8 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import json
-import logging
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    ModelStore,
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAIEmbeddingData,
@@ -54,6 +55,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import (
@@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import VLLMInferenceAdapterConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
 
 
 def build_hf_repo_model_entries():
@@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(
 
 
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+    model_store: ModelStore | None = None
+    _refresh_task: asyncio.Task | None = None
+
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
         self.client = None
 
     async def initialize(self) -> None:
-        pass
+        if not self.config.url:
+            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
+            # or available in distributions like "starter" without causing a ruckus
+            return
+
+        if self.config.refresh_models:
+            self._refresh_task = asyncio.create_task(self._refresh_models())
+
+            def cb(task):
+                import traceback
+
+                if task.cancelled():
+                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
+                elif task.exception():
+                    # print the stack trace for the exception
+                    exc = task.exception()
+                    log.error(f"vLLM background refresh task died: {exc}")
+                    traceback.print_exception(exc)
+                else:
+                    log.error("vLLM background refresh task completed unexpectedly")
+
+            self._refresh_task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        provider_id = self.__provider_id__
+        waited_time = 0
+        while not self.model_store and waited_time < 60:
+            await asyncio.sleep(1)
+            waited_time += 1
+
+        if not self.model_store:
+            raise ValueError("Model store not set after waiting 60 seconds")
+
+        self._lazy_initialize_client()
+        assert self.client is not None  # mypy
+        while True:
+            try:
+                models = []
+                async for m in self.client.models.list():
+                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+                    models.append(
+                        Model(
+                            identifier=m.id,
+                            provider_resource_id=m.id,
+                            provider_id=provider_id,
+                            metadata={},
+                            model_type=model_type,
+                        )
+                    )
+                await self.model_store.update_registered_llm_models(provider_id, models)
+                log.debug(f"vLLM refreshed model list ({len(models)} models)")
+            except Exception as e:
+                log.error(f"vLLM background refresh task failed: {e}")
+            await asyncio.sleep(self.config.refresh_models_interval)
 
     async def shutdown(self) -> None:
-        pass
+        if self._refresh_task:
+            self._refresh_task.cancel()
+            self._refresh_task = None
 
     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             HealthResponse: A dictionary containing the health status.
         """
         try:
+            if not self.config.url:
+                return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
+
             client = self._create_client() if self.client is None else self.client
             _ = [m async for m in client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
@@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if self.client is not None:
             return
 
+        if not self.config.url:
+            raise ValueError(
+                "You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
+            )
+
         log.info(f"Initializing vLLM client with base_url={self.config.url}")
         self.client = self._create_client()
 
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index 27400348a..46573848c 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -26,7 +26,7 @@ providers:
   - provider_id: ${env.ENABLE_VLLM:=__disabled__}
     provider_type: remote::vllm
    config:
-      url: ${env.VLLM_URL}
+      url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}