Taking into account review: ignore, asserts, pytest.fixture

Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
Akram Ben Aissi 2025-09-17 19:49:38 +02:00
parent 5cc605deb5
commit 1a25a17836
2 changed files with 22 additions and 28 deletions


@@ -304,11 +304,10 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     get_api_key = LiteLLMOpenAIMixin.get_api_key
 
     def get_base_url(self) -> str:
-        """Get the base URL, falling back to the api_base from LiteLLMOpenAIMixin or config."""
-        url = self.api_base or self.config.url
-        if not url:
+        """Get the base URL from config."""
+        if not self.config.url:
             raise ValueError("No base URL configured")
-        return url
+        return self.config.url
 
     async def initialize(self) -> None:
         if not self.config.url:
@@ -317,10 +316,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
             )
 
     async def should_refresh_models(self) -> bool:
-        # Get the default value from the field definition
-        default_api_token = self.config.__class__.model_fields["api_token"].default
-        if not self.config.api_token or self.config.api_token == default_api_token:
-            return False
+        # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
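For context, a rough standalone sketch (not part of the commit) contrasting the removed guard with the new behavior. `ExampleConfig` is hypothetical; the real adapter reads these fields from its own Pydantic config class:

from pydantic import BaseModel

class ExampleConfig(BaseModel):
    api_token: str | None = "fake"
    refresh_models: bool = False

config = ExampleConfig(refresh_models=True)

# Old behavior: look up the field's declared default via Pydantic v2's model_fields
# and refuse to refresh unless a non-default token was configured.
default_token = ExampleConfig.model_fields["api_token"].default  # "fake"
old_result = bool(config.api_token) and config.api_token != default_token and config.refresh_models

# New behavior: respect the refresh_models flag directly.
new_result = config.refresh_models

print(old_result, new_result)  # False True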
@@ -373,7 +369,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def completion(  # type: ignore[override]
+    async def completion(  # type: ignore[override]  # Return type more specific than base class, which allows for both streaming and non-streaming responses.
         self,
         model_id: str,
         content: InterleavedContent,
@@ -485,7 +481,8 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
@@ -493,7 +490,8 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     async def _stream_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         stream = await self.client.completions.create(**params)
 
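Both hunks above swap `assert self.client is not None` for an explicit check. A minimal illustration (not from the commit) of the motivation: `assert` statements are stripped when Python runs with the -O flag, so the old guard could silently disappear and the failure would surface later as a less obvious AttributeError. The explicit check always runs and fails fast with a clear message:

client = None  # stands in for an uninitialized self.client

# An `assert client is not None` here would be removed under `python -O`.
# The explicit guard below is always executed.
try:
    if client is None:
        raise RuntimeError("Client is not initialized")
except RuntimeError as exc:
    print(f"caught: {exc}")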


@@ -680,46 +680,42 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 async def test_should_refresh_models():
     """
-    Test the should_refresh_models method with different api_token configurations.
+    Test the should_refresh_models method with different refresh_models configurations.
 
     This test verifies that:
-    1. When api_token is None or empty, should_refresh_models returns False
-    2. When api_token is "fake" (default), should_refresh_models returns False
-    3. When api_token is a real token and refresh_models is True, should_refresh_models returns True
-    4. When api_token is a real token and refresh_models is False, should_refresh_models returns False
+    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
+    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
     """
 
-    # Test case 1: api_token is None, refresh_models is True
+    # Test case 1: refresh_models is True, api_token is None
     config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
     adapter1 = VLLMInferenceAdapter(config1)
     result1 = await adapter1.should_refresh_models()
-    assert result1 is False, "should_refresh_models should return False when api_token is None"
+    assert result1 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 2: api_token is empty string, refresh_models is True
+    # Test case 2: refresh_models is True, api_token is empty string
     config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
     adapter2 = VLLMInferenceAdapter(config2)
     result2 = await adapter2.should_refresh_models()
-    assert result2 is False, "should_refresh_models should return False when api_token is empty"
+    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 3: api_token is "fake" (default), refresh_models is True
+    # Test case 3: refresh_models is True, api_token is "fake" (default)
     config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
     adapter3 = VLLMInferenceAdapter(config3)
     result3 = await adapter3.should_refresh_models()
-    assert result3 is False, "should_refresh_models should return False when api_token is 'fake'"
+    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 4: api_token is real token, refresh_models is True
+    # Test case 4: refresh_models is True, api_token is real token
     config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
     adapter4 = VLLMInferenceAdapter(config4)
     result4 = await adapter4.should_refresh_models()
-    assert result4 is True, "should_refresh_models should return True when api_token is real and refresh_models is True"
+    assert result4 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 5: api_token is real token, refresh_models is False
+    # Test case 5: refresh_models is False, api_token is real token
     config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
     adapter5 = VLLMInferenceAdapter(config5)
     result5 = await adapter5.should_refresh_models()
-    assert result5 is False, (
-        "should_refresh_models should return False when api_token is real but refresh_models is False"
-    )
+    assert result5 is False, "should_refresh_models should return False when refresh_models is False"
 
 
 async def test_provider_data_var_context_propagation(vllm_inference_adapter):
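As a side note, the five near-identical blocks in the rewritten test could also be collapsed with pytest.mark.parametrize. A sketch of that refactor (not part of the commit), assuming the test module's existing imports of VLLMInferenceAdapter and VLLMInferenceAdapterConfig and an async-capable pytest setup such as pytest-asyncio in auto mode:

import pytest

@pytest.mark.parametrize(
    "api_token,refresh_models,expected",
    [
        (None, True, True),
        ("", True, True),
        ("fake", True, True),
        ("real-token-123", True, True),
        ("real-token-456", False, False),
    ],
)
async def test_should_refresh_models_parametrized(api_token, refresh_models, expected):
    # Hypothetical consolidation of the five cases above; behavior should match the commit's test.
    config = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=api_token, refresh_models=refresh_models)
    adapter = VLLMInferenceAdapter(config)
    assert await adapter.should_refresh_models() is expected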