Add dynamic authentication token forwarding support for vLLM provider

This enables per-request authentication tokens for vLLM providers, supporting use cases like RAG operations where different requests may need different authentication tokens. The implementation follows the same pattern as other providers like Together AI, Fireworks, and Passthrough.

- Add LiteLLMOpenAIMixin that manages the vllm_api_token properly

Usage:

- Static: VLLM_API_TOKEN env var or config.api_token
- Dynamic: X-LlamaStack-Provider-Data header with vllm_api_token
All existing functionality is preserved while adding new dynamic capabilities.

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
This commit is contained in:
Akram Ben Aissi 2025-09-12 20:21:53 +02:00
parent b6cb817897
commit 5e74bc7fcf
4 changed files with 153 additions and 12 deletions

View file

@ -78,6 +78,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
description="Remote vLLM inference provider for connecting to vLLM servers.",
),
),

View file

@ -4,9 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import VLLMInferenceAdapterConfig
class VLLMProviderDataValidator(BaseModel):
vllm_api_token: str | None = None
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
@ -55,6 +55,7 @@ from llama_stack.providers.datatypes import (
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
@ -62,6 +63,7 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
@ -281,15 +283,32 @@ async def _process_vllm_chat_completion_stream_response(
yield c
class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""Get the base URL, falling back to the api_base from LiteLLMOpenAIMixin or config."""
url = self.api_base or self.config.url
if not url:
raise ValueError("No base URL configured")
return url
async def initialize(self) -> None:
if not self.config.url:
raise ValueError(
@ -325,10 +344,18 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Only performs the test when a static API key is provided.
Returns:
HealthResponse: A dictionary containing the health status.
"""
# Get the default value from the field definition
default_api_token = self.config.__class__.model_fields["api_token"].default
# Only perform health check if static API key is provided
if not self.config.api_token or self.config.api_token == default_api_token:
return HealthResponse(status=HealthStatus.OK)
try:
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
@ -340,16 +367,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_api_key(self):
return self.config.api_token
def get_base_url(self):
return self.config.url
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion(
async def completion( # type: ignore[override]
self,
model_id: str,
content: InterleavedContent,
@ -411,13 +432,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request, self.client)
return self._stream_chat_completion_with_client(request, self.client)
else:
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
@ -431,9 +453,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
# This method is called from LiteLLMOpenAIMixin.chat_completion
# The response parameter contains the litellm response
# We need to convert it to our format
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _stream_chat_completion_with_client(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""Helper method for streaming with explicit client parameter."""
assert self.client is not None
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)