feat: Add dynamic authentication token forwarding support for vLLM (#3388)

# What does this PR do?


*Add dynamic authentication token forwarding support for vLLM provider*

This enables per-request authentication tokens for vLLM providers,
supporting use cases like RAG operations where different requests may
need different authentication tokens. The implementation follows the
same pattern as other providers like Together AI, Fireworks, and
Passthrough.

- Add LiteLLMOpenAIMixin that manages the vllm_api_token properly

Usage:

- Static: VLLM_API_TOKEN env var or config.api_token
- Dynamic: X-LlamaStack-Provider-Data header with vllm_api_token
All existing functionality is preserved while adding new dynamic
capabilities.


<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->

```
curl -X POST "http://localhost:8000/v1/chat/completions" -H "Authorization: Bearer my-dynamic-token" \
  -H "X-LlamaStack-Provider-Data: {\"vllm_api_token\": \"Bearer my-dynamic-token\", \"vllm_url\": \"http://dynamic-server:8000\"}" \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.1-8b", "messages": [{"role": "user", "content": "Hello!"}]}'
  
```

---------

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
This commit is contained in:
Akram Ben Aissi 2025-09-18 10:13:55 +01:00 committed by GitHub
parent 42c23b45f6
commit 4842145202
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 219 additions and 48 deletions

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError, AsyncOpenAI
@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import (
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
@ -62,6 +64,7 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
@ -281,15 +284,31 @@ async def _process_vllm_chat_completion_stream_response(
yield c
class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""Get the base URL from config."""
if not self.config.url:
raise ValueError("No base URL configured")
return self.config.url
async def initialize(self) -> None:
if not self.config.url:
raise ValueError(
@ -297,6 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
async def should_refresh_models(self) -> bool:
# Strictly respecting the refresh_models directive
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
@ -325,13 +345,19 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Uses the unauthenticated /health endpoint.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
base_url = self.get_base_url()
health_url = urljoin(base_url, "health")
async with httpx.AsyncClient() as client:
response = await client.get(health_url)
response.raise_for_status()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@ -340,16 +366,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_api_key(self):
return self.config.api_token
def get_base_url(self):
return self.config.url
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion(
async def completion( # type: ignore[override] # Return type more specific than base class which is allows for both streaming and non-streaming responses.
self,
model_id: str,
content: InterleavedContent,
@ -411,13 +431,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request, self.client)
return self._stream_chat_completion_with_client(request, self.client)
else:
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
@ -431,9 +452,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
# This method is called from LiteLLMOpenAIMixin.chat_completion
# The response parameter contains the litellm response
# We need to convert it to our format
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _stream_chat_completion_with_client(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""Helper method for streaming with explicit client parameter."""
assert self.client is not None
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)
@ -445,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
assert self.client is not None
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
@ -453,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
assert self.client is not None
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
stream = await self.client.completions.create(**params)