mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 21:22:42 +00:00
Improve VLLM model discovery error handling

- Add comprehensive error handling in the check_model_availability method
- Provide helpful error messages with actionable solutions for 404 errors
- Warn when an API token is set but model discovery is disabled
This commit is contained in:
parent e9214f9004
commit e28bc93635

15 changed files with 69 additions and 50 deletions
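For orientation before the per-file hunks, here is a minimal, self-contained sketch of the behavior this commit adds to check_model_availability. The DummyVLLMConfig dataclass and the free-standing function are hypothetical stand-ins; the real logic lives on VLLMInferenceAdapter in the diffs below.

```python
import logging
from dataclasses import dataclass

log = logging.getLogger(__name__)


@dataclass
class DummyVLLMConfig:
    # Hypothetical stand-in for VLLMInferenceAdapterConfig.
    url: str
    api_token: str | None = None
    enable_model_discovery: bool = True


async def check_model_availability(config: DummyVLLMConfig, model: str, list_model_ids) -> bool:
    """Sketch of the flow added by this commit (not the actual adapter method)."""
    if not config.enable_model_discovery:
        # Trust that the model exists; no HTTP request is made.
        log.debug("Model discovery disabled for vLLM: Trusting model exists")
        if config.api_token:
            # New warning: a token set while discovery is off is often unnecessary.
            log.warning(
                "Model discovery is disabled but VLLM_API_TOKEN is set. "
                "Consider removing it or setting enable_model_discovery=true."
            )
        return True

    try:
        # In the real adapter this iterates the vLLM /v1/models response.
        available_models = await list_model_ids()
    except Exception as e:
        # New error path: point at likely causes (wrong URL, missing /v1/models
        # endpoint, missing VLLM_API_TOKEN) instead of a bare traceback.
        log.error(f"Model discovery failed with the following output from vLLM server: {e}")
        return False

    return model in available_models
```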
@@ -20,7 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
-| `allow_listing_models` | `<class 'bool'>` | No | True | Whether to allow listing models from the vLLM server |
+| `enable_model_discovery` | `<class 'bool'>` | No | True | Whether to enable model discovery from the vLLM server |
 
 ## Sample Configuration
 
@@ -29,5 +29,5 @@ url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
 ```
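A quick aside on the `${env.VAR:=default}` placeholders used throughout these configs: they read an environment variable and fall back to the value after `:=` when it is unset. The standalone Python sketch below is a hypothetical helper (not Llama Stack's actual resolver) showing the equivalent lookup for the new flag.

```python
import os

# Hypothetical stand-in for the ${env.VAR:=default} substitution used in run.yaml.
def resolve_env(var_name: str, default: str) -> str:
    # "${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}" -> env value if set, else "true"
    return os.environ.get(var_name, default)

enable_model_discovery = resolve_env("VLLM_ENABLE_MODEL_DISCOVERY", "true").lower() == "true"
print(enable_model_discovery)  # True unless the env var overrides the default
```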
@@ -43,12 +43,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             await self.update_registered_models(provider_id, models)
 
     async def list_models(self) -> ListModelsResponse:
-        # Check if providers allow listing models before returning models
+        # Check if providers enable model discovery before returning models
         for provider_id, provider in self.impls_by_provider_id.items():
-            allow_listing_models = await provider.allow_listing_models()
-            logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}")
-            if not allow_listing_models:
-                logger.debug(f"Provider {provider_id} has allow_listing_models disabled")
+            enable_model_discovery = await provider.enable_model_discovery()
+            logger.debug(f"Provider {provider_id}: enable_model_discovery={enable_model_discovery}")
+            if not enable_model_discovery:
+                logger.debug(f"Provider {provider_id} has enable_model_discovery disabled")
         return ListModelsResponse(data=await self.get_all_with_type("model"))
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:
@@ -31,7 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
@@ -16,7 +16,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
@@ -31,7 +31,7 @@ providers:
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
@@ -31,7 +31,7 @@ providers:
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
@@ -71,7 +71,7 @@ class MetaReferenceInferenceImpl(
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
 
     async def list_models(self) -> list[Model] | None:
@@ -52,7 +52,7 @@ class SentenceTransformersInferenceImpl(
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
 
    async def list_models(self) -> list[Model] | None:
@@ -34,9 +34,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=False,
         description="Whether to refresh models periodically",
     )
-    allow_listing_models: bool = Field(
+    enable_model_discovery: bool = Field(
         default=True,
-        description="Whether to allow listing models from the vLLM server",
+        description="Whether to enable model discovery from the vLLM server",
     )
 
     @field_validator("tls_verify")
@@ -63,5 +63,5 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
-            "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}",
+            "enable_model_discovery": "${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}",
         }
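Both hunks above touch the provider config class: the renamed field keeps `default=True`, so model discovery stays enabled unless a deployment opts out. The snippet below is a self-contained illustration of that pydantic `Field` pattern using a stand-in model (the real `VLLMInferenceAdapterConfig` has more fields and a different base class):

```python
from pydantic import BaseModel, Field


# Stand-in for the discovery-related part of VLLMInferenceAdapterConfig.
class DemoVLLMConfig(BaseModel):
    url: str
    enable_model_discovery: bool = Field(
        default=True,
        description="Whether to enable model discovery from the vLLM server",
    )


print(DemoVLLMConfig(url="http://localhost:8000/v1").enable_model_discovery)  # True
print(DemoVLLMConfig(url="http://localhost:8000/v1", enable_model_discovery=False).enable_model_discovery)  # False
```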
@@ -282,16 +282,16 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
-    async def allow_listing_models(self) -> bool:
-        # Respecting the allow_listing_models directive
-        result = self.config.allow_listing_models
-        log.debug(f"VLLM allow_listing_models: {result}")
+    async def enable_model_discovery(self) -> bool:
+        # Respecting the enable_model_discovery directive
+        result = self.config.enable_model_discovery
+        log.debug(f"VLLM enable_model_discovery: {result}")
         return result
 
     async def list_models(self) -> list[Model] | None:
-        log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
-        if not self.config.allow_listing_models:
-            log.debug("VLLM list_models returning None due to allow_listing_models=False")
+        log.debug(f"VLLM list_models called, enable_model_discovery={self.config.enable_model_discovery}")
+        if not self.config.enable_model_discovery:
+            log.debug("VLLM list_models returning None due to enable_model_discovery=False")
             return None
 
         models = []
@@ -347,17 +347,22 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         """
         Check if a specific model is available from the vLLM server.
 
-        This method respects the allow_listing_models configuration flag.
-        If allow_listing_models is False, it returns True to allow model registration
+        This method respects the enable_model_discovery configuration flag.
+        If enable_model_discovery is False, it returns True to allow model registration
         without making HTTP requests (trusting that the model exists).
 
         :param model: The model identifier to check.
-        :return: True if the model is available or if allow_listing_models is False, False otherwise.
+        :return: True if the model is available or if enable_model_discovery is False, False otherwise.
         """
-        # Check if provider allows listing models before making HTTP request
-        if not self.config.allow_listing_models:
-            log.debug(
-                "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
-            )
+        # Check if provider enables model discovery before making HTTP request
+        if not self.config.enable_model_discovery:
+            log.debug("Model discovery disabled for vLLM: Trusting model exists")
+            # Warn if API key is set but model discovery is disabled
+            if self.config.api_token:
+                log.warning(
+                    "Model discovery is disabled but VLLM_API_TOKEN is set. "
+                    "If you're not using model discovery, you may not need to set the API token. "
+                    "Consider removing VLLM_API_TOKEN from your configuration or setting enable_model_discovery=true."
+                )
             return True
 
@@ -367,7 +372,21 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
             log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
             return False
 
-        available_models = [m.id async for m in res]
+        try:
+            available_models = [m.id async for m in res]
+        except Exception as e:
+            # Provide helpful error message for model discovery failures
+            log.error(f"Model discovery failed with the following output from vLLM server: {e}.\n")
+            log.error(
+                f"Model discovery failed: This typically occurs when a provider (like vLLM) is configured "
+                f"with model discovery enabled but the provider server doesn't support the /models endpoint. "
+                f"To resolve this, either:\n"
+                f"1. Check that {self.config.url} correctly points to the vLLM server, or\n"
+                f"2. Ensure your provider server supports the /v1/models endpoint and if authenticated that VLLM_API_TOKEN is set, or\n"
+                f"3. Set enable_model_discovery=false for the problematic provider in your configuration\n"
+            )
+            return False
+
         is_available = model in available_models
         log.debug(f"VLLM model {model} availability: {is_available}")
         return is_available
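Taken together, the two hunks above mean a model registration can succeed without any HTTP round-trip when discovery is off. Below is a minimal sketch of exercising that path, mirroring the unit tests further down; the import paths are assumptions (adjust to wherever the adapter and its config actually live in the tree):

```python
import asyncio

# Assumed import paths; adjust to the actual module locations in llama-stack.
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    config = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=False)
    adapter = VLLMInferenceAdapter(config)
    adapter.__provider_id__ = "test-vllm"
    # With discovery disabled the adapter trusts the caller: no request to
    # /v1/models is made, and a warning is logged if api_token is also set.
    assert await adapter.check_model_availability("my-model") is True


asyncio.run(main())
```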
@@ -100,7 +100,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
 
     def get_provider_model_id(self, identifier: str) -> str | None:
@@ -426,5 +426,5 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
@@ -52,7 +52,7 @@ class InferenceImpl(Impl):
     async def should_refresh_models(self):
         return False
 
-    async def allow_listing_models(self):
+    async def enable_model_discovery(self):
         return True
 
     async def list_models(self):
@@ -653,17 +653,17 @@ async def test_should_refresh_models():
     assert result2 is False, "should_refresh_models should return False when refresh_models is False"
 
 
-async def test_allow_listing_models_flag():
+async def test_enable_model_discovery_flag():
     """
-    Test the allow_listing_models flag functionality.
+    Test the enable_model_discovery flag functionality.
 
     This test verifies that:
-    1. When allow_listing_models is True (default), list_models returns models from the server
-    2. When allow_listing_models is False, list_models returns None without calling the server
+    1. When enable_model_discovery is True (default), list_models returns models from the server
+    2. When enable_model_discovery is False, list_models returns None without calling the server
     """
 
-    # Test case 1: allow_listing_models is True (default)
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True)
+    # Test case 1: enable_model_discovery is True (default)
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=True)
     adapter1 = VLLMInferenceAdapter(config1)
     adapter1.__provider_id__ = "test-vllm"
 
@@ -679,14 +679,14 @@ async def test_allow_listing_models_flag():
         mock_client_property.return_value = mock_client
 
         models = await adapter1.list_models()
-        assert models is not None, "list_models should return models when allow_listing_models is True"
+        assert models is not None, "list_models should return models when enable_model_discovery is True"
         assert len(models) == 2, "Should return 2 models"
         assert models[0].identifier == "test-model-1"
         assert models[1].identifier == "test-model-2"
         mock_client.models.list.assert_called_once()
 
-    # Test case 2: allow_listing_models is False
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False)
+    # Test case 2: enable_model_discovery is False
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=False)
     adapter2 = VLLMInferenceAdapter(config2)
     adapter2.__provider_id__ = "test-vllm"
 
@@ -696,15 +696,15 @@ async def test_allow_listing_models_flag():
         mock_client_property.return_value = mock_client
 
         models = await adapter2.list_models()
-        assert models is None, "list_models should return None when allow_listing_models is False"
+        assert models is None, "list_models should return None when enable_model_discovery is False"
         mock_client.models.list.assert_not_called()
 
-    # Test case 3: allow_listing_models defaults to True
+    # Test case 3: enable_model_discovery defaults to True
     config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
     adapter3 = VLLMInferenceAdapter(config3)
     adapter3.__provider_id__ = "test-vllm"
-    result3 = await adapter3.allow_listing_models()
-    assert result3 is True, "allow_listing_models should return True by default"
+    result3 = await adapter3.enable_model_discovery()
+    assert result3 is True, "enable_model_discovery should return True by default"
 
     # Test case 4: refresh_models is True, api_token is real token
     config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@@ -729,7 +729,7 @@ async def test_allow_listing_models_flag():
         mock_client_property.return_value = mock_client
 
         models = await adapter3.list_models()
-        assert models is not None, "list_models should return models when allow_listing_models defaults to True"
+        assert models is not None, "list_models should return models when enable_model_discovery defaults to True"
         assert len(models) == 1, "Should return 1 model"
         assert models[0].identifier == "default-model"
         mock_client.models.list.assert_called_once()
@@ -32,7 +32,7 @@ async def test_setup(cached_disk_dist_registry):
     mock_inference.__provider_spec__ = MagicMock()
     mock_inference.__provider_spec__.api = Api.inference
     mock_inference.register_model = AsyncMock(side_effect=_return_model)
-    mock_inference.allow_listing_models = AsyncMock(return_value=True)
+    mock_inference.enable_model_discovery = AsyncMock(return_value=True)
     routing_table = ModelsRoutingTable(
         impls_by_provider_id={"test_provider": mock_inference},
         dist_registry=cached_disk_dist_registry,