From 1a25a178366122f29c96b9e5ac606fd6e18ee27c Mon Sep 17 00:00:00 2001
From: Akram Ben Aissi
Date: Wed, 17 Sep 2025 19:49:38 +0200
Subject: [PATCH] Taking into account review: ignore, asserts, pytest.fixture

Signed-off-by: Akram Ben Aissi
---
 .../providers/remote/inference/vllm/vllm.py  | 20 ++++++-------
 .../providers/inference/test_remote_vllm.py  | 30 ++++++++-----------
 2 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index a83ec74a3..b4079c39f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -304,11 +304,10 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     get_api_key = LiteLLMOpenAIMixin.get_api_key
 
     def get_base_url(self) -> str:
-        """Get the base URL, falling back to the api_base from LiteLLMOpenAIMixin or config."""
-        url = self.api_base or self.config.url
-        if not url:
+        """Get the base URL from config."""
+        if not self.config.url:
             raise ValueError("No base URL configured")
-        return url
+        return self.config.url
 
     async def initialize(self) -> None:
         if not self.config.url:
@@ -317,10 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         )
 
     async def should_refresh_models(self) -> bool:
-        # Get the default value from the field definition
-        default_api_token = self.config.__class__.model_fields["api_token"].default
-        if not self.config.api_token or self.config.api_token == default_api_token:
-            return False
+        # Strictly respect the refresh_models directive
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
@@ -373,7 +369,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def completion(  # type: ignore[override]
+    async def completion(  # type: ignore[override]  # Return type is more specific than the base class, which allows for both streaming and non-streaming responses.
         self,
         model_id: str,
         content: InterleavedContent,
@@ -485,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
@@ -493,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     async def _stream_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         stream = await self.client.completions.create(**params)
 
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 7bb88c51c..9545e0cf6 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -680,46 +680,42 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
 async def test_should_refresh_models():
     """
-    Test the should_refresh_models method with different api_token configurations.
+    Test the should_refresh_models method with different refresh_models configurations.
 
     This test verifies that:
-    1. When api_token is None or empty, should_refresh_models returns False
-    2. When api_token is "fake" (default), should_refresh_models returns False
-    3. When api_token is a real token and refresh_models is True, should_refresh_models returns True
-    4. When api_token is a real token and refresh_models is False, should_refresh_models returns False
+    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
+    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
     """
 
-    # Test case 1: api_token is None, refresh_models is True
+    # Test case 1: refresh_models is True, api_token is None
     config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
     adapter1 = VLLMInferenceAdapter(config1)
     result1 = await adapter1.should_refresh_models()
-    assert result1 is False, "should_refresh_models should return False when api_token is None"
+    assert result1 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 2: api_token is empty string, refresh_models is True
+    # Test case 2: refresh_models is True, api_token is empty string
     config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
     adapter2 = VLLMInferenceAdapter(config2)
     result2 = await adapter2.should_refresh_models()
-    assert result2 is False, "should_refresh_models should return False when api_token is empty"
+    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 3: api_token is "fake" (default), refresh_models is True
+    # Test case 3: refresh_models is True, api_token is "fake" (default)
     config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
     adapter3 = VLLMInferenceAdapter(config3)
     result3 = await adapter3.should_refresh_models()
-    assert result3 is False, "should_refresh_models should return False when api_token is 'fake'"
+    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 4: api_token is real token, refresh_models is True
+    # Test case 4: refresh_models is True, api_token is real token
     config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
     adapter4 = VLLMInferenceAdapter(config4)
     result4 = await adapter4.should_refresh_models()
-    assert result4 is True, "should_refresh_models should return True when api_token is real and refresh_models is True"
+    assert result4 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 5: api_token is real token, refresh_models is False
+    # Test case 5: refresh_models is False, api_token is real token
     config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
     adapter5 = VLLMInferenceAdapter(config5)
     result5 = await adapter5.should_refresh_models()
-    assert result5 is False, (
-        "should_refresh_models should return False when api_token is real but refresh_models is False"
-    )
+    assert result5 is False, "should_refresh_models should return False when refresh_models is False"
 
 
 async def test_provider_data_var_context_propagation(vllm_inference_adapter):
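
Editorial note (not part of the patch): since the five test cases above differ only in their inputs and expected result, the same coverage could be expressed as a single parametrized test, in line with the review's pytest-oriented suggestions. The sketch below is illustrative only; it assumes the suite runs bare async test functions via pytest-asyncio (or an equivalent plugin) as the surrounding tests suggest, and it assumes VLLMInferenceAdapterConfig is importable from the provider's config module alongside VLLMInferenceAdapter.

import pytest

# Assumed import paths, based on the files touched by this patch.
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


@pytest.mark.parametrize(
    "api_token,refresh_models,expected",
    [
        (None, True, True),                # no token, refresh requested
        ("", True, True),                  # empty token, refresh requested
        ("fake", True, True),              # default token, refresh requested
        ("real-token-123", True, True),    # real token, refresh requested
        ("real-token-456", False, False),  # real token, refresh disabled
    ],
)
async def test_should_refresh_models_parametrized(api_token, refresh_models, expected):
    # should_refresh_models now depends only on the refresh_models flag,
    # regardless of the api_token value.
    config = VLLMInferenceAdapterConfig(
        url="http://test.localhost", api_token=api_token, refresh_models=refresh_models
    )
    adapter = VLLMInferenceAdapter(config)
    assert await adapter.should_refresh_models() is expected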