diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 64196152b..0eb4cf104 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -78,6 +78,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.vllm",
             config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
             description="Remote vLLM inference provider for connecting to vLLM servers.",
         ),
     ),
diff --git a/llama_stack/providers/remote/inference/vllm/__init__.py b/llama_stack/providers/remote/inference/vllm/__init__.py
index e4322a6aa..1f196e507 100644
--- a/llama_stack/providers/remote/inference/vllm/__init__.py
+++ b/llama_stack/providers/remote/inference/vllm/__init__.py
@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from pydantic import BaseModel
+
 from .config import VLLMInferenceAdapterConfig
 
 
+class VLLMProviderDataValidator(BaseModel):
+    vllm_api_token: str | None = None
+
+
 async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
     from .vllm import VLLMInferenceAdapter
 
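
The VLLMProviderDataValidator added above is what lets callers send a per-request vLLM API token as
provider data instead of hard-coding it in the adapter config. A minimal client-side sketch of that
flow; the header name, server URL, route, and model id below are assumptions for illustration, not
something this diff defines:

# Illustrative only: pass a dynamic vLLM token via the provider-data header.
import json

import httpx

provider_data = {"vllm_api_token": "my-dynamic-token"}  # field name matches VLLMProviderDataValidator
resp = httpx.post(
    "http://localhost:8321/v1/inference/chat-completion",  # assumed Llama Stack endpoint
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},  # assumed header name
    json={
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())
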
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 77f5d82af..d1ed6365a 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
 import httpx
@@ -55,6 +55,7 @@ from llama_stack.providers.datatypes import (
     HealthStatus,
     ModelsProtocolPrivate,
 )
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
@@ -62,6 +63,7 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_message_to_openai_dict,
+    convert_openai_chat_completion_stream,
     convert_tool_call,
     get_sampling_options,
     process_chat_completion_stream_response,
@@ -281,15 +283,32 @@ async def _process_vllm_chat_completion_stream_response(
         yield c
 
 
-class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            build_hf_repo_model_entries(),
+            litellm_provider_name="vllm",
+            api_key_from_config=config.api_token,
+            provider_data_api_key_field="vllm_api_token",
+            openai_compat_api_base=config.url,
+        )
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
 
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """Get the base URL, falling back to the api_base from LiteLLMOpenAIMixin or config."""
+        url = self.api_base or self.config.url
+        if not url:
+            raise ValueError("No base URL configured")
+        return url
+
     async def initialize(self) -> None:
         if not self.config.url:
             raise ValueError(
@@ -325,10 +344,18 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
         Performs a health check by verifying connectivity to the remote vLLM server.
         This method is used by the Provider API to verify that the service is running
         correctly.
+        Only performs the test when a static API key is provided.
 
         Returns:
             HealthResponse: A dictionary containing the health status.
         """
+        # Get the default value from the field definition
+        default_api_token = self.config.__class__.model_fields["api_token"].default
+
+        # Only perform health check if static API key is provided
+        if not self.config.api_token or self.config.api_token == default_api_token:
+            return HealthResponse(status=HealthStatus.OK)
+
         try:
             _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
@@ -340,16 +367,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
-    def get_api_key(self):
-        return self.config.api_token
-
-    def get_base_url(self):
-        return self.config.url
-
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def completion(
+    async def completion(  # type: ignore[override]
         self,
         model_id: str,
         content: InterleavedContent,
@@ -411,13 +432,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             tool_config=tool_config,
         )
         if stream:
-            return self._stream_chat_completion(request, self.client)
+            return self._stream_chat_completion_with_client(request, self.client)
         else:
             return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await client.chat.completions.create(**params)
         choice = r.choices[0]
@@ -431,9 +453,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
         )
         return result
 
-    async def _stream_chat_completion(
+    async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
+        # This method is called from LiteLLMOpenAIMixin.chat_completion
+        # The response parameter contains the litellm response
+        # We need to convert it to our format
+        async def _stream_generator():
+            async for chunk in response:
+                yield chunk
+
+        async for chunk in convert_openai_chat_completion_stream(
+            _stream_generator(), enable_incremental_tool_calls=True
+        ):
+            yield chunk
+
+    async def _stream_chat_completion_with_client(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        """Helper method for streaming with explicit client parameter."""
+        assert self.client is not None
         params = await self._get_params(request)
         stream = await client.chat.completions.create(**params)
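
With the changes above, the adapter resolves credentials in two layers: a per-request vllm_api_token
from provider data (surfaced through the aliased LiteLLMOpenAIMixin.get_api_key) takes precedence
over the static config.api_token, and get_base_url prefers the mixin's api_base before falling back
to config.url. A standalone sketch of that precedence; the function and argument names here are
illustrative only, the real logic lives in LiteLLMOpenAIMixin and the get_base_url override above:

def resolve_vllm_credentials(
    config_api_token: str | None,
    config_url: str | None,
    provider_data_token: str | None = None,
    mixin_api_base: str | None = None,
) -> tuple[str | None, str]:
    # Per-request provider data ("vllm_api_token") wins over the static config token.
    api_key = provider_data_token or config_api_token
    # The base URL prefers the mixin's api_base and falls back to config.url.
    base_url = mixin_api_base or config_url
    if not base_url:
        raise ValueError("No base URL configured")
    return api_key, base_url


# Example: a dynamic token overrides the static default from VLLMInferenceAdapterConfig.
assert resolve_vllm_credentials("fake", "http://vllm:8000/v1", "real-key") == ("real-key", "http://vllm:8000/v1")
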
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 61b16b5d1..96a57c3c8 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -66,11 +66,15 @@ def mock_openai_models_list():
     yield mock_list
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 async def vllm_inference_adapter():
     config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
     inference_adapter = VLLMInferenceAdapter(config)
     inference_adapter.model_store = AsyncMock()
+    # Mock the __provider_spec__ attribute that would normally be set by the resolver
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_type = "vllm-inference"
+    inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
     await inference_adapter.initialize()
     return inference_adapter
 
@@ -120,6 +124,10 @@ async def test_tool_call_response(vllm_inference_adapter):
         mock_client.chat.completions.create = AsyncMock()
         mock_create_client.return_value = mock_client
 
+        # Mock the model to return a proper provider_resource_id
+        mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
+        vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             UserMessage(content="How many?"),
@@ -558,6 +566,9 @@ async def test_health_status_success(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status OK, only
     when the connection to the vLLM server is successful.
     """
+    # Set a non-default API token to enable health check
+    vllm_inference_adapter.config.api_token = "real-api-key"
+
     with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         # Create mock client and models
         mock_client = MagicMock()
@@ -589,6 +600,9 @@ async def test_health_status_failure(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status ERROR
     and an appropriate error message when the connection to the vLLM server fails.
     """
+    # Set a non-default API token to enable health check
+    vllm_inference_adapter.config.api_token = "real-api-key"
+
     with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         # Create mock client and models
         mock_client = MagicMock()
@@ -613,6 +627,23 @@ async def test_health_status_failure(vllm_inference_adapter):
         mock_models.list.assert_called_once()
 
 
+async def test_health_status_no_static_api_key(vllm_inference_adapter):
+    """
+    Test the health method of VLLM InferenceAdapter when no static API key is provided.
+
+    This test verifies that the health method returns a HealthResponse with status OK
+    without performing any connectivity test when no static API key is provided.
+    """
+    # Ensure api_token is the default value (no static API key)
+    vllm_inference_adapter.config.api_token = "fake"
+
+    # Call the health method
+    health_response = await vllm_inference_adapter.health()
+
+    # Verify the response
+    assert health_response["status"] == HealthStatus.OK
+
+
 async def test_openai_chat_completion_is_async(vllm_inference_adapter):
     """
     Verify that openai_chat_completion is async and doesn't block the event loop.
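
The health() gating exercised by the test above compares the configured token against the pydantic
field default, so a deployment that never set api_token skips the connectivity probe entirely. A
small self-contained sketch of that comparison against a stand-in config class (illustrative, not
the real VLLMInferenceAdapterConfig; "fake" mirrors the default assumed by the test):

from pydantic import BaseModel


class FakeVLLMConfig(BaseModel):
    api_token: str | None = "fake"


default_api_token = FakeVLLMConfig.model_fields["api_token"].default
assert FakeVLLMConfig().api_token == default_api_token  # health() short-circuits to OK
assert FakeVLLMConfig(api_token="real-api-key").api_token != default_api_token  # health() lists models
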
@@ -656,3 +687,69 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
+
+
+async def test_provider_data_var_context_propagation(vllm_inference_adapter):
+    """
+    Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
+    This ensures that dynamic provider data (like API tokens) can be passed through context.
+    Note: The base URL is always taken from config.url, not from provider data.
+    """
+    # Mock the AsyncOpenAI class to capture provider data
+    with (
+        patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
+        patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
+    ):
+        mock_client = AsyncMock()
+        mock_client.chat.completions.create = AsyncMock()
+        mock_openai_class.return_value = mock_client
+
+        # Mock provider data to return test data
+        mock_provider_data = MagicMock()
+        mock_provider_data.vllm_api_token = "test-token-123"
+        mock_provider_data.vllm_url = "http://test-server:8000/v1"
+        mock_get_provider_data.return_value = mock_provider_data
+
+        # Mock the model
+        mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
+        vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
+        try:
+            # Execute chat completion
+            await vllm_inference_adapter.chat_completion(
+                "test-model",
+                [UserMessage(content="Hello")],
+                stream=False,
+                tools=None,
+                tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+            )
+
+            # Verify that ALL client calls were made with the correct parameters
+            calls = mock_openai_class.call_args_list
+            incorrect_calls = []
+
+            for i, call in enumerate(calls):
+                api_key = call[1]["api_key"]
+                base_url = call[1]["base_url"]
+
+                if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
+                    incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
+
+            if incorrect_calls:
+                error_msg = (
+                    f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
+                )
+                for incorrect_call in incorrect_calls:
+                    error_msg += f"  Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
+                error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
+                raise AssertionError(error_msg)
+
+            # Ensure at least one call was made
+            assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
+
+            # Verify that chat completion was called
+            mock_client.chat.completions.create.assert_called_once()
+
+        finally:
+            # Clean up context
+            pass
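
One way to run just the two tests added in this diff locally; the file path and test names come from
the hunks above, while the pytest.main invocation style is only a suggestion:

# Illustrative only: run the new vLLM adapter tests by keyword.
import pytest

pytest.main([
    "-q",
    "tests/unit/providers/inference/test_remote_vllm.py",
    "-k", "no_static_api_key or provider_data_var_context_propagation",
])
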