feat: Add allow_listing_models

• Add allow_listing_models configuration flag to VLLM provider to control model listing behavior
• Implement allow_listing_models() method across all providers with default implementations in base classes
• Prevent HTTP requests to /v1/models endpoint when allow_listing_models=false for improved security and performance
• Fix unit tests to include allow_listing_models method in test classes and mock objects
This commit is contained in:
Akram Ben Aissi 2025-10-04 00:17:53 +02:00
parent 188a56af5c
commit e9214f9004
15 changed files with 143 additions and 25 deletions

View file

@ -636,27 +636,75 @@ async def test_should_refresh_models():
Test the should_refresh_models method with different refresh_models configurations.
This test verifies that:
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
1. When refresh_models is True, should_refresh_models returns True
2. When refresh_models is False, should_refresh_models returns False
"""
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
# Test case 1: refresh_models is True
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
# Test case 2: refresh_models is False
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=False)
adapter2 = VLLMInferenceAdapter(config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
assert result2 is False, "should_refresh_models should return False when refresh_models is False"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
async def test_allow_listing_models_flag():
"""
Test the allow_listing_models flag functionality.
This test verifies that:
1. When allow_listing_models is True (default), list_models returns models from the server
2. When allow_listing_models is False, list_models returns None without calling the server
"""
# Test case 1: allow_listing_models is True (default)
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1.__provider_id__ = "test-vllm"
# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
async def mock_models_list():
yield OpenAIModel(id="test-model-1", created=1, object="model", owned_by="test")
yield OpenAIModel(id="test-model-2", created=2, object="model", owned_by="test")
mock_client.models.list.return_value = mock_models_list()
mock_client_property.return_value = mock_client
models = await adapter1.list_models()
assert models is not None, "list_models should return models when allow_listing_models is True"
assert len(models) == 2, "Should return 2 models"
assert models[0].identifier == "test-model-1"
assert models[1].identifier == "test-model-2"
mock_client.models.list.assert_called_once()
# Test case 2: allow_listing_models is False
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False)
adapter2 = VLLMInferenceAdapter(config2)
adapter2.__provider_id__ = "test-vllm"
# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client_property.return_value = mock_client
models = await adapter2.list_models()
assert models is None, "list_models should return None when allow_listing_models is False"
mock_client.models.list.assert_not_called()
# Test case 3: allow_listing_models defaults to True
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
adapter3 = VLLMInferenceAdapter(config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
adapter3.__provider_id__ = "test-vllm"
result3 = await adapter3.allow_listing_models()
assert result3 is True, "allow_listing_models should return True by default"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@ -670,6 +718,22 @@ async def test_should_refresh_models():
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
async def mock_models_list():
yield OpenAIModel(id="default-model", created=1, object="model", owned_by="test")
mock_client.models.list.return_value = mock_models_list()
mock_client_property.return_value = mock_client
models = await adapter3.list_models()
assert models is not None, "list_models should return models when allow_listing_models defaults to True"
assert len(models) == 1, "Should return 1 model"
assert models[0].identifier == "default-model"
mock_client.models.list.assert_called_once()
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""