From 5e74bc7fcff7a7833c0cc3d731234f98950ff895 Mon Sep 17 00:00:00 2001 From: Akram Ben Aissi Date: Fri, 12 Sep 2025 20:21:53 +0200 Subject: [PATCH] Add dynamic authentication token forwarding support for vLLM provider This enables per-request authentication tokens for vLLM providers, supporting use cases like RAG operations where different requests may need different authentication tokens. The implementation follows the same pattern as other providers like Together AI, Fireworks, and Passthrough. - Add LiteLLMOpenAIMixin that manages the vllm_api_token properly Usage: - Static: VLLM_API_TOKEN env var or config.api_token - Dynamic: X-LlamaStack-Provider-Data header with vllm_api_token All existing functionality is preserved while adding new dynamic capabilities. Signed-off-by: Akram Ben Aissi --- llama_stack/providers/registry/inference.py | 1 + .../remote/inference/vllm/__init__.py | 6 ++ .../providers/remote/inference/vllm/vllm.py | 59 ++++++++--- .../providers/inference/test_remote_vllm.py | 99 ++++++++++++++++++- 4 files changed, 153 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 64196152b..0eb4cf104 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -78,6 +78,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.vllm", config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", + provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", description="Remote vLLM inference provider for connecting to vLLM servers.", ), ), diff --git a/llama_stack/providers/remote/inference/vllm/__init__.py b/llama_stack/providers/remote/inference/vllm/__init__.py index e4322a6aa..1f196e507 100644 --- a/llama_stack/providers/remote/inference/vllm/__init__.py +++ b/llama_stack/providers/remote/inference/vllm/__init__.py @@ -4,9 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from pydantic import BaseModel + from .config import VLLMInferenceAdapterConfig +class VLLMProviderDataValidator(BaseModel): + vllm_api_token: str | None = None + + async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps): from .vllm import VLLMInferenceAdapter diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 77f5d82af..d1ed6365a 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from typing import Any import httpx @@ -55,6 +55,7 @@ from llama_stack.providers.datatypes import ( HealthStatus, ModelsProtocolPrivate, ) +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_hf_repo_model_entry, @@ -62,6 +63,7 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( UnparseableToolCall, convert_message_to_openai_dict, + convert_openai_chat_completion_stream, convert_tool_call, get_sampling_options, process_chat_completion_stream_response, @@ -281,15 +283,32 @@ async def _process_vllm_chat_completion_stream_response( yield c -class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate): +class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate): # automatically set by the resolver when instantiating the provider __provider_id__: str model_store: ModelStore | None = None def __init__(self, config: VLLMInferenceAdapterConfig) -> None: + LiteLLMOpenAIMixin.__init__( + self, + build_hf_repo_model_entries(), + litellm_provider_name="vllm", + api_key_from_config=config.api_token, + provider_data_api_key_field="vllm_api_token", + openai_compat_api_base=config.url, + ) self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.config = config + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: + """Get the base URL, falling back to the api_base from LiteLLMOpenAIMixin or config.""" + url = self.api_base or self.config.url + if not url: + raise ValueError("No base URL configured") + return url + async def initialize(self) -> None: if not self.config.url: raise ValueError( @@ -325,10 +344,18 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate): Performs a health check by verifying connectivity to the remote vLLM server. This method is used by the Provider API to verify that the service is running correctly. + Only performs the test when a static API key is provided. Returns: HealthResponse: A dictionary containing the health status. """ + # Get the default value from the field definition + default_api_token = self.config.__class__.model_fields["api_token"].default + + # Only perform health check if static API key is provided + if not self.config.api_token or self.config.api_token == default_api_token: + return HealthResponse(status=HealthStatus.OK) + try: _ = [m async for m in self.client.models.list()] # Ensure the client is initialized return HealthResponse(status=HealthStatus.OK) @@ -340,16 +367,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate): raise ValueError("Model store not set") return await self.model_store.get_model(model_id) - def get_api_key(self): - return self.config.api_token - - def get_base_url(self): - return self.config.url - def get_extra_client_params(self): return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)} - async def completion( + async def completion( # type: ignore[override] self, model_id: str, content: InterleavedContent, @@ -411,13 +432,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate): tool_config=tool_config, ) if stream: - return self._stream_chat_completion(request, self.client) + return self._stream_chat_completion_with_client(request, self.client) else: return await self._nonstream_chat_completion(request, self.client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: AsyncOpenAI ) -> ChatCompletionResponse: + assert self.client is not None params = await self._get_params(request) r = await client.chat.completions.create(**params) choice = r.choices[0] @@ -431,9 +453,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate): ) return result - async def _stream_chat_completion( + async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]: + # This method is called from LiteLLMOpenAIMixin.chat_completion + # The response parameter contains the litellm response + # We need to convert it to our format + async def _stream_generator(): + async for chunk in response: + yield chunk + + async for chunk in convert_openai_chat_completion_stream( + _stream_generator(), enable_incremental_tool_calls=True + ): + yield chunk + + async def _stream_chat_completion_with_client( self, request: ChatCompletionRequest, client: AsyncOpenAI ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + """Helper method for streaming with explicit client parameter.""" + assert self.client is not None params = await self._get_params(request) stream = await client.chat.completions.create(**params) diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 61b16b5d1..96a57c3c8 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -66,11 +66,15 @@ def mock_openai_models_list(): yield mock_list -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") async def vllm_inference_adapter(): config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345") inference_adapter = VLLMInferenceAdapter(config) inference_adapter.model_store = AsyncMock() + # Mock the __provider_spec__ attribute that would normally be set by the resolver + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_type = "vllm-inference" + inference_adapter.__provider_spec__.provider_data_validator = MagicMock() await inference_adapter.initialize() return inference_adapter @@ -120,6 +124,10 @@ async def test_tool_call_response(vllm_inference_adapter): mock_client.chat.completions.create = AsyncMock() mock_create_client.return_value = mock_client + # Mock the model to return a proper provider_resource_id + mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + messages = [ SystemMessage(content="You are a helpful assistant"), UserMessage(content="How many?"), @@ -558,6 +566,9 @@ async def test_health_status_success(vllm_inference_adapter): This test verifies that the health method returns a HealthResponse with status OK, only when the connection to the vLLM server is successful. """ + # Set a non-default API token to enable health check + vllm_inference_adapter.config.api_token = "real-api-key" + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: # Create mock client and models mock_client = MagicMock() @@ -589,6 +600,9 @@ async def test_health_status_failure(vllm_inference_adapter): This test verifies that the health method returns a HealthResponse with status ERROR and an appropriate error message when the connection to the vLLM server fails. """ + # Set a non-default API token to enable health check + vllm_inference_adapter.config.api_token = "real-api-key" + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: # Create mock client and models mock_client = MagicMock() @@ -613,6 +627,23 @@ async def test_health_status_failure(vllm_inference_adapter): mock_models.list.assert_called_once() +async def test_health_status_no_static_api_key(vllm_inference_adapter): + """ + Test the health method of VLLM InferenceAdapter when no static API key is provided. + + This test verifies that the health method returns a HealthResponse with status OK + without performing any connectivity test when no static API key is provided. + """ + # Ensure api_token is the default value (no static API key) + vllm_inference_adapter.config.api_token = "fake" + + # Call the health method + health_response = await vllm_inference_adapter.health() + + # Verify the response + assert health_response["status"] == HealthStatus.OK + + async def test_openai_chat_completion_is_async(vllm_inference_adapter): """ Verify that openai_chat_completion is async and doesn't block the event loop. @@ -656,3 +687,69 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): assert mock_create_client.call_count == 4 # no cheating assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" + + +async def test_provider_data_var_context_propagation(vllm_inference_adapter): + """ + Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter. + This ensures that dynamic provider data (like API tokens) can be passed through context. + Note: The base URL is always taken from config.url, not from provider data. + """ + # Mock the AsyncOpenAI class to capture provider data + with ( + patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class, + patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data, + ): + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock() + mock_openai_class.return_value = mock_client + + # Mock provider data to return test data + mock_provider_data = MagicMock() + mock_provider_data.vllm_api_token = "test-token-123" + mock_provider_data.vllm_url = "http://test-server:8000/v1" + mock_get_provider_data.return_value = mock_provider_data + + # Mock the model + mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + + try: + # Execute chat completion + await vllm_inference_adapter.chat_completion( + "test-model", + [UserMessage(content="Hello")], + stream=False, + tools=None, + tool_config=ToolConfig(tool_choice=ToolChoice.auto), + ) + + # Verify that ALL client calls were made with the correct parameters + calls = mock_openai_class.call_args_list + incorrect_calls = [] + + for i, call in enumerate(calls): + api_key = call[1]["api_key"] + base_url = call[1]["base_url"] + + if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345": + incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url}) + + if incorrect_calls: + error_msg = ( + f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n" + ) + for incorrect_call in incorrect_calls: + error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n" + error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'" + raise AssertionError(error_msg) + + # Ensure at least one call was made + assert len(calls) >= 1, "No AsyncOpenAI client calls were made" + + # Verify that chat completion was called + mock_client.chat.completions.create.assert_called_once() + + finally: + # Clean up context + pass