Add dynamic authentication token forwarding support for vLLM provider

This enables per-request authentication tokens for vLLM providers, supporting use cases like RAG operations where different requests may need different authentication tokens. The implementation follows the same pattern as other providers like Together AI, Fireworks, and Passthrough.

- Add LiteLLMOpenAIMixin that manages the vllm_api_token properly

Usage:

- Static: VLLM_API_TOKEN env var or config.api_token
- Dynamic: X-LlamaStack-Provider-Data header with vllm_api_token
All existing functionality is preserved while adding new dynamic capabilities.

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
This commit is contained in:
Akram Ben Aissi 2025-09-12 20:21:53 +02:00
parent b6cb817897
commit 5e74bc7fcf
4 changed files with 153 additions and 12 deletions

View file

@ -66,11 +66,15 @@ def mock_openai_models_list():
yield mock_list
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
@ -120,6 +124,10 @@ async def test_tool_call_response(vllm_inference_adapter):
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
# Mock the model to return a proper provider_resource_id
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
@ -558,6 +566,9 @@ async def test_health_status_success(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the vLLM server is successful.
"""
# Set a non-default API token to enable health check
vllm_inference_adapter.config.api_token = "real-api-key"
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
@ -589,6 +600,9 @@ async def test_health_status_failure(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status ERROR
and an appropriate error message when the connection to the vLLM server fails.
"""
# Set a non-default API token to enable health check
vllm_inference_adapter.config.api_token = "real-api-key"
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
@ -613,6 +627,23 @@ async def test_health_status_failure(vllm_inference_adapter):
mock_models.list.assert_called_once()
async def test_health_status_no_static_api_key(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when no static API key is provided.
This test verifies that the health method returns a HealthResponse with status OK
without performing any connectivity test when no static API key is provided.
"""
# Ensure api_token is the default value (no static API key)
vllm_inference_adapter.config.api_token = "fake"
# Call the health method
health_response = await vllm_inference_adapter.health()
# Verify the response
assert health_response["status"] == HealthStatus.OK
async def test_openai_chat_completion_is_async(vllm_inference_adapter):
"""
Verify that openai_chat_completion is async and doesn't block the event loop.
@ -656,3 +687,69 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.chat_completion(
"test-model",
[UserMessage(content="Hello")],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass