From e9214f9004b9d08b94fa2ffe88c3da74f0a88fc5 Mon Sep 17 00:00:00 2001 From: Akram Ben Aissi Date: Sat, 4 Oct 2025 00:17:53 +0200 Subject: [PATCH] feat: Add allow_listing_models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add allow_listing_models configuration flag to VLLM provider to control model listing behavior • Implement allow_listing_models() method across all providers with default implementations in base classes • Prevent HTTP requests to /v1/models endpoint when allow_listing_models=false for improved security and performance • Fix unit tests to include allow_listing_models method in test classes and mock objects --- docs/docs/providers/inference/remote_vllm.mdx | 2 + llama_stack/core/routing_tables/models.py | 6 ++ llama_stack/distributions/ci-tests/run.yaml | 1 + .../distributions/postgres-demo/run.yaml | 1 + .../distributions/starter-gpu/run.yaml | 1 + llama_stack/distributions/starter/run.yaml | 1 + .../inference/meta_reference/inference.py | 3 + .../sentence_transformers.py | 3 + .../providers/remote/inference/vllm/config.py | 5 ++ .../providers/remote/inference/vllm/vllm.py | 49 ++++++++--- .../utils/inference/model_registry.py | 3 + .../providers/utils/inference/openai_mixin.py | 3 + .../routers/test_routing_tables.py | 3 + .../providers/inference/test_remote_vllm.py | 86 ++++++++++++++++--- tests/unit/server/test_access_control.py | 1 + 15 files changed, 143 insertions(+), 25 deletions(-) diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx index 598f97b19..efa863016 100644 --- a/docs/docs/providers/inference/remote_vllm.mdx +++ b/docs/docs/providers/inference/remote_vllm.mdx @@ -20,6 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers. | `api_token` | `str \| None` | No | fake | The API token | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. 
| | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically | +| `allow_listing_models` | `<class 'bool'>` | No | True | Whether to allow listing models from the vLLM server | ## Sample Configuration @@ -28,4 +29,5 @@ url: ${env.VLLM_URL:=} max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} +allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} ``` diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 641c73c16..0a9850674 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -43,6 +43,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): await self.update_registered_models(provider_id, models) async def list_models(self) -> ListModelsResponse: + # Log each provider's allow_listing_models setting; the actual gating happens in each provider's own list_models() + for provider_id, provider in self.impls_by_provider_id.items(): + allow_listing_models = await provider.allow_listing_models() + logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}") + if not allow_listing_models: + logger.debug(f"Provider {provider_id} has allow_listing_models disabled") return ListModelsResponse(data=await self.get_all_with_type("model")) async def openai_list_models(self) -> OpenAIListModelsResponse: diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index b14477a9a..e70c11100 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -31,6 +31,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} + allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} - provider_id: ${env.TGI_URL:+tgi} provider_type: remote::tgi config: diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml index 0cf0e82e6..67691e5cf 100644 --- a/llama_stack/distributions/postgres-demo/run.yaml +++ b/llama_stack/distributions/postgres-demo/run.yaml @@ -16,6 +16,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} + allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} - provider_id: sentence-transformers provider_type: inline::sentence-transformers vector_io: diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index de5fe5681..fb29f7407 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -31,6 +31,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} + allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} - provider_id: ${env.TGI_URL:+tgi} provider_type: remote::tgi config: diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index c440e4e4b..d338944bb 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -31,6 +31,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} + allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} - provider_id: ${env.TGI_URL:+tgi} provider_type: remote::tgi config: diff --git 
a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index fd65fa10d..3c003bbca 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -71,6 +71,9 @@ class MetaReferenceInferenceImpl( async def should_refresh_models(self) -> bool: return False + async def allow_listing_models(self) -> bool: + return True + async def list_models(self) -> list[Model] | None: return None diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index b984d97bf..542c4bceb 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -52,6 +52,9 @@ class SentenceTransformersInferenceImpl( async def should_refresh_models(self) -> bool: return False + async def allow_listing_models(self) -> bool: + return True + async def list_models(self) -> list[Model] | None: return [ Model( diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 86ef3fe26..327718800 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): default=False, description="Whether to refresh models periodically", ) + allow_listing_models: bool = Field( + default=True, + description="Whether to allow listing models from the vLLM server", + ) @field_validator("tls_verify") @classmethod @@ -59,4 +63,5 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}", "api_token": "${env.VLLM_API_TOKEN:=fake}", "tls_verify": "${env.VLLM_TLS_VERIFY:=true}", + "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}", } diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 54ac8e1dc..87a78c0f1 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -282,7 +282,18 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro # Strictly respecting the refresh_models directive return self.config.refresh_models + async def allow_listing_models(self) -> bool: + # Respecting the allow_listing_models directive + result = self.config.allow_listing_models + log.debug(f"VLLM allow_listing_models: {result}") + return result + async def list_models(self) -> list[Model] | None: + log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}") + if not self.config.allow_listing_models: + log.debug("VLLM list_models returning None due to allow_listing_models=False") + return None + models = [] async for m in self.client.models.list(): model_type = ModelType.llm # unclear how to determine embedding vs. 
llm models @@ -332,24 +343,34 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro def get_extra_client_params(self): return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)} - async def register_model(self, model: Model) -> Model: - try: - model = await self.register_helper.register_model(model) - except ValueError: - pass # Ignore statically unknown model, will check live listing + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from the vLLM server. + + This method respects the allow_listing_models configuration flag. + If allow_listing_models is False, it returns True to allow model registration + without making HTTP requests (trusting that the model exists). + + :param model: The model identifier to check. + :return: True if the model is available or if allow_listing_models is False, False otherwise. + """ + # Check if provider allows listing models before making HTTP request + if not self.config.allow_listing_models: + log.debug( + "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)" + ) + return True + try: res = self.client.models.list() except APIConnectionError as e: - raise ValueError( - f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL." - ) from e + log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}") + return False + available_models = [m.id async for m in res] - if model.provider_resource_id not in available_models: - raise ValueError( - f"Model {model.provider_resource_id} is not being served by vLLM. " - f"Available models: {', '.join(available_models)}" - ) - return model + is_available = model in available_models + log.debug(f"VLLM model {model} availability: {is_available}") + return is_available async def _get_params(self, request: ChatCompletionRequest) -> dict: options = get_sampling_options(request.sampling_params) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 4913c2e1f..4a4fa7adf 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -100,6 +100,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate): async def should_refresh_models(self) -> bool: return False + async def allow_listing_models(self) -> bool: + return True + def get_provider_model_id(self, identifier: str) -> str | None: return self.alias_to_provider_id_map.get(identifier, None) diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 4354b067e..4164edfa3 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -425,3 +425,6 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): async def should_refresh_models(self) -> bool: return False + + async def allow_listing_models(self) -> bool: + return True diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 54a9dd72e..524d650da 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -52,6 +52,9 @@ class InferenceImpl(Impl): async def should_refresh_models(self): return False + async def allow_listing_models(self): + return True + async def 
list_models(self): return [ Model( diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index cd31e4943..6675a5901 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -636,27 +636,75 @@ async def test_should_refresh_models(): Test the should_refresh_models method with different refresh_models configurations. This test verifies that: - 1. When refresh_models is True, should_refresh_models returns True regardless of api_token - 2. When refresh_models is False, should_refresh_models returns False regardless of api_token + 1. When refresh_models is True, should_refresh_models returns True + 2. When refresh_models is False, should_refresh_models returns False """ - # Test case 1: refresh_models is True, api_token is None - config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True) + # Test case 1: refresh_models is True + config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True) adapter1 = VLLMInferenceAdapter(config1) result1 = await adapter1.should_refresh_models() assert result1 is True, "should_refresh_models should return True when refresh_models is True" - # Test case 2: refresh_models is True, api_token is empty string - config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True) + # Test case 2: refresh_models is False + config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=False) adapter2 = VLLMInferenceAdapter(config2) result2 = await adapter2.should_refresh_models() - assert result2 is True, "should_refresh_models should return True when refresh_models is True" + assert result2 is False, "should_refresh_models should return False when refresh_models is False" - # Test case 3: refresh_models is True, api_token is "fake" (default) - config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True) + +async def test_allow_listing_models_flag(): + """ + Test the allow_listing_models flag functionality. + + This test verifies that: + 1. When allow_listing_models is True (default), list_models returns models from the server + 2. 
When allow_listing_models is False, list_models returns None without calling the server + """ + + # Test case 1: allow_listing_models is True (default) + config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True) + adapter1 = VLLMInferenceAdapter(config1) + adapter1.__provider_id__ = "test-vllm" + + # Mock the client.models.list() method + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + + async def mock_models_list(): + yield OpenAIModel(id="test-model-1", created=1, object="model", owned_by="test") + yield OpenAIModel(id="test-model-2", created=2, object="model", owned_by="test") + + mock_client.models.list.return_value = mock_models_list() + mock_client_property.return_value = mock_client + + models = await adapter1.list_models() + assert models is not None, "list_models should return models when allow_listing_models is True" + assert len(models) == 2, "Should return 2 models" + assert models[0].identifier == "test-model-1" + assert models[1].identifier == "test-model-2" + mock_client.models.list.assert_called_once() + + # Test case 2: allow_listing_models is False + config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False) + adapter2 = VLLMInferenceAdapter(config2) + adapter2.__provider_id__ = "test-vllm" + + # Mock the client.models.list() method + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + mock_client_property.return_value = mock_client + + models = await adapter2.list_models() + assert models is None, "list_models should return None when allow_listing_models is False" + mock_client.models.list.assert_not_called() + + # Test case 3: allow_listing_models defaults to True + config3 = VLLMInferenceAdapterConfig(url="http://test.localhost") adapter3 = VLLMInferenceAdapter(config3) - result3 = await adapter3.should_refresh_models() - assert result3 is True, "should_refresh_models should return True when refresh_models is True" + adapter3.__provider_id__ = "test-vllm" + result3 = await adapter3.allow_listing_models() + assert result3 is True, "allow_listing_models should return True by default" # Test case 4: refresh_models is True, api_token is real token config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True) @@ -670,6 +718,22 @@ async def test_should_refresh_models(): result5 = await adapter5.should_refresh_models() assert result5 is False, "should_refresh_models should return False when refresh_models is False" + # Mock the client.models.list() method + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + + async def mock_models_list(): + yield OpenAIModel(id="default-model", created=1, object="model", owned_by="test") + + mock_client.models.list.return_value = mock_models_list() + mock_client_property.return_value = mock_client + + models = await adapter3.list_models() + assert models is not None, "list_models should return models when allow_listing_models defaults to True" + assert len(models) == 1, "Should return 1 model" + assert models[0].identifier == "default-model" + mock_client.models.list.assert_called_once() + async def test_provider_data_var_context_propagation(vllm_inference_adapter): """ diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 
55449804a..8752dfc28 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -32,6 +32,7 @@ async def test_setup(cached_disk_dist_registry): mock_inference.__provider_spec__ = MagicMock() mock_inference.__provider_spec__.api = Api.inference mock_inference.register_model = AsyncMock(side_effect=_return_model) + mock_inference.allow_listing_models = AsyncMock(return_value=True) routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, dist_registry=cached_disk_dist_registry,
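
A minimal sketch of the adapter-level behavior this patch aims for, based on the unit tests above. The import paths and the `VLLMInferenceAdapter(config)` constructor mirror the files touched by the diff; the URL and model id are placeholders, and no vLLM server needs to be reachable because the disabled flag short-circuits before any HTTP request:

```python
import asyncio

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # Placeholder URL; nothing is contacted when allow_listing_models=False.
    config = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1", allow_listing_models=False)
    adapter = VLLMInferenceAdapter(config)

    # allow_listing_models() simply reflects the config flag.
    assert await adapter.allow_listing_models() is False

    # list_models() returns None without calling GET /v1/models.
    assert await adapter.list_models() is None

    # check_model_availability() trusts the caller and returns True without
    # contacting the server, so registration of a statically configured model
    # still succeeds ("my-model" is an arbitrary placeholder id).
    assert await adapter.check_model_availability("my-model") is True


if __name__ == "__main__":
    asyncio.run(main())
```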
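
Usage note: the run.yaml entries above default the flag through `${env.VLLM_ALLOW_LISTING_MODELS:=true}`. Assuming the same `${env.VAR:=default}` substitution semantics already used for `VLLM_TLS_VERIFY` and the other vLLM settings, exporting `VLLM_ALLOW_LISTING_MODELS=false` before starting the stack should disable model listing for the vLLM provider without editing the distribution files.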