From e28bc936351772135719e16ba8f5e6fd5506a85a Mon Sep 17 00:00:00 2001 From: Akram Ben Aissi Date: Mon, 6 Oct 2025 12:56:05 +0200 Subject: [PATCH] Improve VLLM model discovery error handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add comprehensive error handling in check_model_availability method • Provide helpful error messages with actionable solutions for 404 errors • Warn when API token is set but model discovery is disabled --- docs/docs/providers/inference/remote_vllm.mdx | 4 +- llama_stack/core/routing_tables/models.py | 10 ++-- llama_stack/distributions/ci-tests/run.yaml | 2 +- .../distributions/postgres-demo/run.yaml | 2 +- .../distributions/starter-gpu/run.yaml | 2 +- llama_stack/distributions/starter/run.yaml | 2 +- .../inference/meta_reference/inference.py | 2 +- .../sentence_transformers.py | 2 +- .../providers/remote/inference/vllm/config.py | 6 +-- .../providers/remote/inference/vllm/vllm.py | 51 +++++++++++++------ .../utils/inference/model_registry.py | 2 +- .../providers/utils/inference/openai_mixin.py | 2 +- .../routers/test_routing_tables.py | 2 +- .../providers/inference/test_remote_vllm.py | 28 +++++----- tests/unit/server/test_access_control.py | 2 +- 15 files changed, 69 insertions(+), 50 deletions(-) diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx index efa863016..884ca8922 100644 --- a/docs/docs/providers/inference/remote_vllm.mdx +++ b/docs/docs/providers/inference/remote_vllm.mdx @@ -20,7 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers. | `api_token` | `str \| None` | No | fake | The API token | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. 
| | `refresh_models` | `` | No | False | Whether to refresh models periodically | -| `allow_listing_models` | `` | No | True | Whether to allow listing models from the vLLM server | +| `enable_model_discovery` | `` | No | True | Whether to enable model discovery from the vLLM server | ## Sample Configuration @@ -29,5 +29,5 @@ url: ${env.VLLM_URL:=} max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} -allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} +enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true} ``` diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 0a9850674..e75155158 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -43,12 +43,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): await self.update_registered_models(provider_id, models) async def list_models(self) -> ListModelsResponse: - # Check if providers allow listing models before returning models + # Check if providers enable model discovery before returning models for provider_id, provider in self.impls_by_provider_id.items(): - allow_listing_models = await provider.allow_listing_models() - logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}") - if not allow_listing_models: - logger.debug(f"Provider {provider_id} has allow_listing_models disabled") + enable_model_discovery = await provider.enable_model_discovery() + logger.debug(f"Provider {provider_id}: enable_model_discovery={enable_model_discovery}") + if not enable_model_discovery: + logger.debug(f"Provider {provider_id} has enable_model_discovery disabled") return ListModelsResponse(data=await self.get_all_with_type("model")) async def openai_list_models(self) -> OpenAIListModelsResponse: diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index e70c11100..81c947f56 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -31,7 +31,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} - allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} + enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true} - provider_id: ${env.TGI_URL:+tgi} provider_type: remote::tgi config: diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml index 67691e5cf..98e784e76 100644 --- a/llama_stack/distributions/postgres-demo/run.yaml +++ b/llama_stack/distributions/postgres-demo/run.yaml @@ -16,7 +16,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} - allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} + enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true} - provider_id: sentence-transformers provider_type: inline::sentence-transformers vector_io: diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index fb29f7407..187e3ccde 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -31,7 +31,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} - allow_listing_models: 
${env.VLLM_ALLOW_LISTING_MODELS:=true} + enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true} - provider_id: ${env.TGI_URL:+tgi} provider_type: remote::tgi config: diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index d338944bb..d02bd439d 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -31,7 +31,7 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} - allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true} + enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true} - provider_id: ${env.TGI_URL:+tgi} provider_type: remote::tgi config: diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 3c003bbca..f272040c0 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -71,7 +71,7 @@ class MetaReferenceInferenceImpl( async def should_refresh_models(self) -> bool: return False - async def allow_listing_models(self) -> bool: + async def enable_model_discovery(self) -> bool: return True async def list_models(self) -> list[Model] | None: diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 542c4bceb..3dd5d2b89 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -52,7 +52,7 @@ class SentenceTransformersInferenceImpl( async def should_refresh_models(self) -> bool: return False - async def allow_listing_models(self) -> bool: + async def enable_model_discovery(self) -> bool: return True async def list_models(self) -> list[Model] | None: diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 327718800..3887107dd 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -34,9 +34,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): default=False, description="Whether to refresh models periodically", ) - allow_listing_models: bool = Field( + enable_model_discovery: bool = Field( default=True, - description="Whether to allow listing models from the vLLM server", + description="Whether to enable model discovery from the vLLM server", ) @field_validator("tls_verify") @@ -63,5 +63,5 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}", "api_token": "${env.VLLM_API_TOKEN:=fake}", "tls_verify": "${env.VLLM_TLS_VERIFY:=true}", - "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}", + "enable_model_discovery": "${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}", } diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 87a78c0f1..305990f06 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -282,16 +282,16 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro # Strictly respecting the refresh_models directive return 
self.config.refresh_models
 
-    async def allow_listing_models(self) -> bool:
-        # Respecting the allow_listing_models directive
-        result = self.config.allow_listing_models
-        log.debug(f"VLLM allow_listing_models: {result}")
+    async def enable_model_discovery(self) -> bool:
+        # Respecting the enable_model_discovery directive
+        result = self.config.enable_model_discovery
+        log.debug(f"VLLM enable_model_discovery: {result}")
         return result
 
     async def list_models(self) -> list[Model] | None:
-        log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
-        if not self.config.allow_listing_models:
-            log.debug("VLLM list_models returning None due to allow_listing_models=False")
+        log.debug(f"VLLM list_models called, enable_model_discovery={self.config.enable_model_discovery}")
+        if not self.config.enable_model_discovery:
+            log.debug("VLLM list_models returning None due to enable_model_discovery=False")
             return None
 
         models = []
@@ -347,18 +347,23 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         """
         Check if a specific model is available from the vLLM server.
 
-        This method respects the allow_listing_models configuration flag.
-        If allow_listing_models is False, it returns True to allow model registration
+        This method respects the enable_model_discovery configuration flag.
+        If enable_model_discovery is False, it returns True to allow model registration
         without making HTTP requests (trusting that the model exists).
 
         :param model: The model identifier to check.
-        :return: True if the model is available or if allow_listing_models is False, False otherwise.
+        :return: True if the model is available or if enable_model_discovery is False, False otherwise.
         """
-        # Check if provider allows listing models before making HTTP request
-        if not self.config.allow_listing_models:
-            log.debug(
-                "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
-            )
+        # Check if the provider enables model discovery before making an HTTP request
+        if not self.config.enable_model_discovery:
+            log.debug("Model discovery disabled for vLLM: trusting that the model exists")
+            # Warn if an API token is set but model discovery is disabled
+            if self.config.api_token and self.config.api_token != "fake":
+                log.warning(
+                    "Model discovery is disabled but VLLM_API_TOKEN is set. "
+                    "If the token was only needed for model discovery, consider removing VLLM_API_TOKEN "
+                    "from your configuration, or set enable_model_discovery=true."
+                )
             return True
 
         try:
@@ -367,7 +372,21 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
             log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
             return False
 
-        available_models = [m.id async for m in res]
+        try:
+            available_models = [m.id async for m in res]
+        except Exception as e:
+            # Provide a helpful error message for model discovery failures
+            log.error(f"Model discovery failed with the following error from the vLLM server: {e}")
+            log.error(
+                f"Model discovery failed: this typically occurs when a provider (like vLLM) is configured "
+                f"with model discovery enabled but the server doesn't support the /v1/models endpoint. "
+                f"To resolve this:\n"
+                f"1. Check that {self.config.url} correctly points to the vLLM server, or\n"
+                f"2. Ensure the server supports the /v1/models endpoint and, if it requires authentication, that VLLM_API_TOKEN is set, or\n"
+                f"3. 
Set enable_model_discovery=false for the problematic provider in your configuration\n" + ) + return False + is_available = model in available_models log.debug(f"VLLM model {model} availability: {is_available}") return is_available diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 4a4fa7adf..17c43fe38 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -100,7 +100,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): async def should_refresh_models(self) -> bool: return False - async def allow_listing_models(self) -> bool: + async def enable_model_discovery(self) -> bool: return True def get_provider_model_id(self, identifier: str) -> str | None: diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 4164edfa3..a00a45963 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -426,5 +426,5 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): async def should_refresh_models(self) -> bool: return False - async def allow_listing_models(self) -> bool: + async def enable_model_discovery(self) -> bool: return True diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 524d650da..ec7aca27e 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -52,7 +52,7 @@ class InferenceImpl(Impl): async def should_refresh_models(self): return False - async def allow_listing_models(self): + async def enable_model_discovery(self): return True async def list_models(self): diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 6675a5901..701282179 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -653,17 +653,17 @@ async def test_should_refresh_models(): assert result2 is False, "should_refresh_models should return False when refresh_models is False" -async def test_allow_listing_models_flag(): +async def test_enable_model_discovery_flag(): """ - Test the allow_listing_models flag functionality. + Test the enable_model_discovery flag functionality. This test verifies that: - 1. When allow_listing_models is True (default), list_models returns models from the server - 2. When allow_listing_models is False, list_models returns None without calling the server + 1. When enable_model_discovery is True (default), list_models returns models from the server + 2. 
When enable_model_discovery is False, list_models returns None without calling the server """ - # Test case 1: allow_listing_models is True (default) - config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True) + # Test case 1: enable_model_discovery is True (default) + config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=True) adapter1 = VLLMInferenceAdapter(config1) adapter1.__provider_id__ = "test-vllm" @@ -679,14 +679,14 @@ async def test_allow_listing_models_flag(): mock_client_property.return_value = mock_client models = await adapter1.list_models() - assert models is not None, "list_models should return models when allow_listing_models is True" + assert models is not None, "list_models should return models when enable_model_discovery is True" assert len(models) == 2, "Should return 2 models" assert models[0].identifier == "test-model-1" assert models[1].identifier == "test-model-2" mock_client.models.list.assert_called_once() - # Test case 2: allow_listing_models is False - config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False) + # Test case 2: enable_model_discovery is False + config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=False) adapter2 = VLLMInferenceAdapter(config2) adapter2.__provider_id__ = "test-vllm" @@ -696,15 +696,15 @@ async def test_allow_listing_models_flag(): mock_client_property.return_value = mock_client models = await adapter2.list_models() - assert models is None, "list_models should return None when allow_listing_models is False" + assert models is None, "list_models should return None when enable_model_discovery is False" mock_client.models.list.assert_not_called() - # Test case 3: allow_listing_models defaults to True + # Test case 3: enable_model_discovery defaults to True config3 = VLLMInferenceAdapterConfig(url="http://test.localhost") adapter3 = VLLMInferenceAdapter(config3) adapter3.__provider_id__ = "test-vllm" - result3 = await adapter3.allow_listing_models() - assert result3 is True, "allow_listing_models should return True by default" + result3 = await adapter3.enable_model_discovery() + assert result3 is True, "enable_model_discovery should return True by default" # Test case 4: refresh_models is True, api_token is real token config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True) @@ -729,7 +729,7 @@ async def test_allow_listing_models_flag(): mock_client_property.return_value = mock_client models = await adapter3.list_models() - assert models is not None, "list_models should return models when allow_listing_models defaults to True" + assert models is not None, "list_models should return models when enable_model_discovery defaults to True" assert len(models) == 1, "Should return 1 model" assert models[0].identifier == "default-model" mock_client.models.list.assert_called_once() diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 8752dfc28..3cb393e0e 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -32,7 +32,7 @@ async def test_setup(cached_disk_dist_registry): mock_inference.__provider_spec__ = MagicMock() mock_inference.__provider_spec__.api = Api.inference mock_inference.register_model = AsyncMock(side_effect=_return_model) - mock_inference.allow_listing_models = AsyncMock(return_value=True) + mock_inference.enable_model_discovery = 
AsyncMock(return_value=True) routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, dist_registry=cached_disk_dist_registry,
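
Below is a minimal usage sketch, not part of the patch, showing how the renamed flag behaves after this change. It mirrors the unit tests above; the import paths follow the files touched in the diff, and the URL and model name are placeholders.

```python
# Illustrative sketch only (mirrors tests/unit/providers/inference/test_remote_vllm.py above);
# the URL and model name below are placeholders, not values from the patch.
import asyncio

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # With discovery disabled, check_model_availability() trusts the registered model
    # and returns True without querying the server's /v1/models endpoint.
    config = VLLMInferenceAdapterConfig(
        url="http://localhost:8000/v1",
        enable_model_discovery=False,
    )
    adapter = VLLMInferenceAdapter(config)
    adapter.__provider_id__ = "vllm"
    assert await adapter.check_model_availability("my-model") is True

    # With discovery enabled (the default), the same call lists models from the
    # server and logs the actionable error message introduced in this patch if
    # the /v1/models endpoint is unavailable.


asyncio.run(main())
```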