diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 64196152b..0eb4cf104 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -78,6 +78,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.vllm",
             config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
             description="Remote vLLM inference provider for connecting to vLLM servers.",
         ),
     ),
diff --git a/llama_stack/providers/remote/inference/vllm/__init__.py b/llama_stack/providers/remote/inference/vllm/__init__.py
index e4322a6aa..1f196e507 100644
--- a/llama_stack/providers/remote/inference/vllm/__init__.py
+++ b/llama_stack/providers/remote/inference/vllm/__init__.py
@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from pydantic import BaseModel
+
 from .config import VLLMInferenceAdapterConfig
 
 
+class VLLMProviderDataValidator(BaseModel):
+    vllm_api_token: str | None = None
+
+
 async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
     from .vllm import VLLMInferenceAdapter
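A note for reviewers on the new validator: per-request provider data arrives as JSON (via the X-LlamaStack-Provider-Data request header) and is validated against this model before the adapter reads the token. A minimal sketch of that behavior, assuming pydantic v2 and an illustrative token value:

from llama_stack.providers.remote.inference.vllm import VLLMProviderDataValidator

# A request that carries a per-request token validates and exposes it.
data = VLLMProviderDataValidator.model_validate({"vllm_api_token": "sk-example"})
assert data.vllm_api_token == "sk-example"

# The field is optional, so requests without provider data still validate.
assert VLLMProviderDataValidator.model_validate({}).vllm_api_token is None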
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 77f5d82af..b4079c39f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
+from urllib.parse import urljoin
 
 import httpx
 from openai import APIConnectionError, AsyncOpenAI
@@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import (
     HealthStatus,
     ModelsProtocolPrivate,
 )
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
@@ -62,6 +64,7 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_message_to_openai_dict,
+    convert_openai_chat_completion_stream,
     convert_tool_call,
     get_sampling_options,
     process_chat_completion_stream_response,
@@ -281,15 +284,31 @@ async def _process_vllm_chat_completion_stream_response(
         yield c
 
 
-class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            build_hf_repo_model_entries(),
+            litellm_provider_name="vllm",
+            api_key_from_config=config.api_token,
+            provider_data_api_key_field="vllm_api_token",
+            openai_compat_api_base=config.url,
+        )
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
 
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """Get the base URL from config."""
+        if not self.config.url:
+            raise ValueError("No base URL configured")
+        return self.config.url
+
     async def initialize(self) -> None:
         if not self.config.url:
             raise ValueError(
@@ -297,6 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             )
 
     async def should_refresh_models(self) -> bool:
+        # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
@@ -325,13 +345,19 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
         Performs a health check by verifying connectivity to the remote vLLM server.
         This method is used by the Provider API to verify
         that the service is running correctly.
+        Uses the unauthenticated /health endpoint.
 
         Returns:
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
-            return HealthResponse(status=HealthStatus.OK)
+            base_url = self.get_base_url()
+            health_url = urljoin(base_url, "health")
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(health_url)
+                response.raise_for_status()
+                return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
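One subtlety in the health-check change above: urljoin replaces the last path segment when the base URL has no trailing slash, so an OpenAI-compatible base such as http://host:8000/v1 resolves to the root-level /health endpoint that vLLM serves. A quick demonstration (hostname invented):

from urllib.parse import urljoin

# Without a trailing slash, "health" replaces the "/v1" segment,
# landing on vLLM's root-level health endpoint.
assert urljoin("http://vllm.example:8000/v1", "health") == "http://vllm.example:8000/health"

# With a trailing slash, the segment would be kept instead.
assert urljoin("http://vllm.example:8000/v1/", "health") == "http://vllm.example:8000/v1/health"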
@@ -340,16 +366,10 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
-    def get_api_key(self):
-        return self.config.api_token
-
-    def get_base_url(self):
-        return self.config.url
-
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def completion(
+    async def completion(  # type: ignore[override]  # Return type is more specific than the base class, which allows for both streaming and non-streaming responses.
         self,
         model_id: str,
         content: InterleavedContent,
@@ -411,13 +431,14 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             tool_config=tool_config,
         )
         if stream:
-            return self._stream_chat_completion(request, self.client)
+            return self._stream_chat_completion_with_client(request, self.client)
         else:
             return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await client.chat.completions.create(**params)
         choice = r.choices[0]
@@ -431,9 +452,24 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
         )
         return result
 
-    async def _stream_chat_completion(
+    async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
+        # Called from LiteLLMOpenAIMixin.chat_completion: the response
+        # parameter carries the litellm stream, which we convert to our
+        # chunk format.
+        async def _stream_generator():
+            async for chunk in response:
+                yield chunk
+
+        async for chunk in convert_openai_chat_completion_stream(
+            _stream_generator(), enable_incremental_tool_calls=True
+        ):
+            yield chunk
+
+    async def _stream_chat_completion_with_client(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        """Helper method for streaming with an explicit client parameter."""
+        assert self.client is not None
         params = await self._get_params(request)
 
         stream = await client.chat.completions.create(**params)
@@ -445,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
@@ -453,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     async def _stream_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         stream = await self.client.completions.create(**params)
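For context on the _stream_generator indirection above: the litellm response is async-iterable but not a plain async generator, so the wrapper re-yields chunks before handing them to convert_openai_chat_completion_stream. A self-contained sketch of that re-yield pattern, with all names invented for illustration:

import asyncio
from collections.abc import AsyncIterator


async def passthrough(source: AsyncIterator[str]) -> AsyncIterator[str]:
    # Re-yield from any async iterable, producing a true async generator
    # that downstream converters can consume.
    async for item in source:
        yield item


async def fake_stream() -> AsyncIterator[str]:
    for piece in ("Hello", ", ", "world"):
        yield piece


async def main() -> None:
    async for piece in passthrough(fake_stream()):
        print(piece, end="")
    print()


asyncio.run(main())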
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 61b16b5d1..9545e0cf6 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -66,11 +66,15 @@ def mock_openai_models_list():
         yield mock_list
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 async def vllm_inference_adapter():
     config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
     inference_adapter = VLLMInferenceAdapter(config)
     inference_adapter.model_store = AsyncMock()
+    # Mock the __provider_spec__ attribute that would normally be set by the resolver
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_type = "vllm-inference"
+    inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
     await inference_adapter.initialize()
     return inference_adapter
 
@@ -120,6 +124,10 @@ async def test_tool_call_response(vllm_inference_adapter):
         mock_client.chat.completions.create = AsyncMock()
         mock_create_client.return_value = mock_client
 
+        # Mock the model to return a proper provider_resource_id
+        mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
+        vllm_inference_adapter.model_store.get_model.return_value = mock_model
+
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             UserMessage(content="How many?"),
@@ -555,31 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when the connection is successful.
 
-    This test verifies that the health method returns a HealthResponse with status OK, only
-    when the connection to the vLLM server is successful.
+    This test verifies that the health method returns a HealthResponse with status OK
+    when the /health endpoint responds successfully.
     """
-    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
-        # Create mock client and models
-        mock_client = MagicMock()
-        mock_models = MagicMock()
+    with patch("httpx.AsyncClient") as mock_client_class:
+        # Create mock response
+        mock_response = MagicMock()
+        mock_response.raise_for_status.return_value = None
 
-        # Create a mock async iterator that yields a model when iterated
-        async def mock_list():
-            for model in [MagicMock()]:
-                yield model
-
-        # Set up the models.list to return our mock async iterator
-        mock_models.list.return_value = mock_list()
-        mock_client.models = mock_models
-        mock_create_client.return_value = mock_client
+        # Create mock client instance
+        mock_client_instance = MagicMock()
+        mock_client_instance.get = AsyncMock(return_value=mock_response)
+        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
 
         # Call the health method
         health_response = await vllm_inference_adapter.health()
 
+        # Verify the response
         assert health_response["status"] == HealthStatus.OK
 
-        # Verify that models.list was called
-        mock_models.list.assert_called_once()
+        # Verify that the health endpoint was called
+        mock_client_instance.get.assert_called_once()
+        call_args = mock_client_instance.get.call_args[0]
+        assert call_args[0].endswith("/health")
 
 
 async def test_health_status_failure(vllm_inference_adapter):
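The __aenter__ wiring used in this test is the standard way to mock `async with httpx.AsyncClient()`: MagicMock has supported the async context-manager protocol since Python 3.8. A stand-alone version of the pattern (URL and helper invented for illustration):

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import httpx


async def fetch_status(url: str) -> int:
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        return response.status_code


with patch("httpx.AsyncClient") as mock_client_class:
    mock_response = MagicMock(status_code=200)
    mock_instance = MagicMock()
    mock_instance.get = AsyncMock(return_value=mock_response)
    # `async with` calls __aenter__ on the instance returned by the class.
    mock_client_class.return_value.__aenter__.return_value = mock_instance

    assert asyncio.run(fetch_status("http://vllm.example:8000/health")) == 200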
""" - with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: - # Create mock client and models - mock_client = MagicMock() - mock_models = MagicMock() - - # Create a mock async iterator that raises an exception when iterated - async def mock_list(): - raise Exception("Connection failed") - yield # Unreachable code - - # Set up the models.list to return our mock async iterator - mock_models.list.return_value = mock_list() - mock_client.models = mock_models - mock_create_client.return_value = mock_client + with patch("httpx.AsyncClient") as mock_client_class: + # Create mock client instance that raises an exception + mock_client_instance = MagicMock() + mock_client_instance.get.side_effect = Exception("Connection failed") + mock_client_class.return_value.__aenter__.return_value = mock_client_instance # Call the health method health_response = await vllm_inference_adapter.health() + # Verify the response assert health_response["status"] == HealthStatus.ERROR assert "Health check failed: Connection failed" in health_response["message"] - mock_models.list.assert_called_once() + +async def test_health_status_no_static_api_key(vllm_inference_adapter): + """ + Test the health method of VLLM InferenceAdapter when no static API key is provided. + + This test verifies that the health method returns a HealthResponse with status OK + when the /health endpoint responds successfully, regardless of API token configuration. + """ + with patch("httpx.AsyncClient") as mock_client_class: + # Create mock response + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + + # Create mock client instance + mock_client_instance = MagicMock() + mock_client_instance.get = AsyncMock(return_value=mock_response) + mock_client_class.return_value.__aenter__.return_value = mock_client_instance + + # Call the health method + health_response = await vllm_inference_adapter.health() + + # Verify the response + assert health_response["status"] == HealthStatus.OK async def test_openai_chat_completion_is_async(vllm_inference_adapter): @@ -656,3 +676,109 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): assert mock_create_client.call_count == 4 # no cheating assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" + + +async def test_should_refresh_models(): + """ + Test the should_refresh_models method with different refresh_models configurations. + + This test verifies that: + 1. When refresh_models is True, should_refresh_models returns True regardless of api_token + 2. 
@@ -656,3 +676,109 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
+
+
+async def test_should_refresh_models():
+    """
+    Test the should_refresh_models method with different refresh_models configurations.
+
+    This test verifies that:
+    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
+    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
+    """
+
+    # Test case 1: refresh_models is True, api_token is None
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
+    adapter1 = VLLMInferenceAdapter(config1)
+    result1 = await adapter1.should_refresh_models()
+    assert result1 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 2: refresh_models is True, api_token is an empty string
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
+    adapter2 = VLLMInferenceAdapter(config2)
+    result2 = await adapter2.should_refresh_models()
+    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 3: refresh_models is True, api_token is "fake" (the default)
+    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
+    adapter3 = VLLMInferenceAdapter(config3)
+    result3 = await adapter3.should_refresh_models()
+    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 4: refresh_models is True, api_token is a real token
+    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
+    adapter4 = VLLMInferenceAdapter(config4)
+    result4 = await adapter4.should_refresh_models()
+    assert result4 is True, "should_refresh_models should return True when refresh_models is True"
+
+    # Test case 5: refresh_models is False, api_token is a real token
+    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
+    adapter5 = VLLMInferenceAdapter(config5)
+    result5 = await adapter5.should_refresh_models()
+    assert result5 is False, "should_refresh_models should return False when refresh_models is False"
+
+
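The five cases above differ only in their inputs; if this duplication grows, the same coverage could be written as one parametrized test. A sketch, not part of this diff, assuming pytest-asyncio is configured as for the rest of the file:

import pytest

from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


@pytest.mark.parametrize(
    "api_token, refresh_models",
    [
        (None, True),
        ("", True),
        ("fake", True),
        ("real-token-123", True),
        ("real-token-456", False),
    ],
)
async def test_should_refresh_models_parametrized(api_token, refresh_models):
    # should_refresh_models must mirror the refresh_models flag exactly,
    # regardless of the api_token value.
    config = VLLMInferenceAdapterConfig(
        url="http://test.localhost", api_token=api_token, refresh_models=refresh_models
    )
    adapter = VLLMInferenceAdapter(config)
    assert await adapter.should_refresh_models() is refresh_models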
+ """ + # Mock the AsyncOpenAI class to capture provider data + with ( + patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class, + patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data, + ): + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock() + mock_openai_class.return_value = mock_client + + # Mock provider data to return test data + mock_provider_data = MagicMock() + mock_provider_data.vllm_api_token = "test-token-123" + mock_provider_data.vllm_url = "http://test-server:8000/v1" + mock_get_provider_data.return_value = mock_provider_data + + # Mock the model + mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + + try: + # Execute chat completion + await vllm_inference_adapter.chat_completion( + "test-model", + [UserMessage(content="Hello")], + stream=False, + tools=None, + tool_config=ToolConfig(tool_choice=ToolChoice.auto), + ) + + # Verify that ALL client calls were made with the correct parameters + calls = mock_openai_class.call_args_list + incorrect_calls = [] + + for i, call in enumerate(calls): + api_key = call[1]["api_key"] + base_url = call[1]["base_url"] + + if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345": + incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url}) + + if incorrect_calls: + error_msg = ( + f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n" + ) + for incorrect_call in incorrect_calls: + error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n" + error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'" + raise AssertionError(error_msg) + + # Ensure at least one call was made + assert len(calls) >= 1, "No AsyncOpenAI client calls were made" + + # Verify that chat completion was called + mock_client.chat.completions.create.assert_called_once() + + finally: + # Clean up context + pass