feat: Add allow_listing_models

• Add an allow_listing_models configuration flag to the vLLM provider to control model listing behavior
• Implement an allow_listing_models() method across all providers, with default implementations in the base classes
• Prevent HTTP requests to the /v1/models endpoint when allow_listing_models=false, for improved security and performance
• Fix unit tests to include the allow_listing_models method in test classes and mock objects
Akram Ben Aissi 2025-10-04 00:17:53 +02:00
parent 188a56af5c
commit e9214f9004
15 changed files with 143 additions and 25 deletions
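The unit-test changes are not included in the hunks below. As a rough illustration only (the class name and details are hypothetical, not taken from this commit), a provider test double now needs to expose the new coroutine alongside the existing hooks:

# Hypothetical test double -- illustrative only, not code from this commit.
class FakeInferenceProvider:
    async def should_refresh_models(self) -> bool:
        return False

    async def allow_listing_models(self) -> bool:
        # Mirrors the new base-class default: listing is allowed.
        return True

    async def list_models(self):
        return None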

View file

@@ -43,6 +43,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             await self.update_registered_models(provider_id, models)
 
     async def list_models(self) -> ListModelsResponse:
+        # Check if providers allow listing models before returning models
+        for provider_id, provider in self.impls_by_provider_id.items():
+            allow_listing_models = await provider.allow_listing_models()
+            logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}")
+            if not allow_listing_models:
+                logger.debug(f"Provider {provider_id} has allow_listing_models disabled")
         return ListModelsResponse(data=await self.get_all_with_type("model"))
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:

View file

@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:

View file

@@ -16,6 +16,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:

View file

@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:

View file

@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:

View file

@@ -71,6 +71,9 @@ class MetaReferenceInferenceImpl(
     async def should_refresh_models(self) -> bool:
         return False
 
+    async def allow_listing_models(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return None

View file

@@ -52,6 +52,9 @@ class SentenceTransformersInferenceImpl(
     async def should_refresh_models(self) -> bool:
         return False
 
+    async def allow_listing_models(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return [
             Model(

View file

@@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=False,
         description="Whether to refresh models periodically",
     )
+    allow_listing_models: bool = Field(
+        default=True,
+        description="Whether to allow listing models from the vLLM server",
+    )
 
     @field_validator("tls_verify")
     @classmethod
@@ -59,4 +63,5 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
+            "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}",
         }
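Besides the ${env.VLLM_ALLOW_LISTING_MODELS:=true} substitution in run configs, the field can be set directly when the config object is built in Python. A minimal sketch, assuming the conventional import path for this module (the path is not shown in this diff) and that url and api_token are accepted as in the sample config:

# Sketch only -- import path and constructor usage are assumptions, not shown in this commit.
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

config = VLLMInferenceAdapterConfig(
    url="http://localhost:8000/v1",
    api_token="fake",
    allow_listing_models=False,  # disable /v1/models listing for this provider
)
print(config.allow_listing_models)  # False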

View file

@@ -282,7 +282,18 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
+    async def allow_listing_models(self) -> bool:
+        # Respecting the allow_listing_models directive
+        result = self.config.allow_listing_models
+        log.debug(f"VLLM allow_listing_models: {result}")
+        return result
+
     async def list_models(self) -> list[Model] | None:
+        log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
+        if not self.config.allow_listing_models:
+            log.debug("VLLM list_models returning None due to allow_listing_models=False")
+            return None
+
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -332,24 +343,34 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def register_model(self, model: Model) -> Model:
-        try:
-            model = await self.register_helper.register_model(model)
-        except ValueError:
-            pass  # Ignore statically unknown model, will check live listing
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the vLLM server.
+
+        This method respects the allow_listing_models configuration flag.
+        If allow_listing_models is False, it returns True to allow model registration
+        without making HTTP requests (trusting that the model exists).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available or if allow_listing_models is False, False otherwise.
+        """
+        # Check if provider allows listing models before making HTTP request
+        if not self.config.allow_listing_models:
+            log.debug(
+                "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
+            )
+            return True
+
         try:
             res = self.client.models.list()
         except APIConnectionError as e:
-            raise ValueError(
-                f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
-            ) from e
+            log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
+            return False
+
         available_models = [m.id async for m in res]
-        if model.provider_resource_id not in available_models:
-            raise ValueError(
-                f"Model {model.provider_resource_id} is not being served by vLLM. "
-                f"Available models: {', '.join(available_models)}"
-            )
-        return model
+        is_available = model in available_models
+        log.debug(f"VLLM model {model} availability: {is_available}")
+        return is_available
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         options = get_sampling_options(request.sampling_params)
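Taken together, the new check_model_availability path makes a pure trust decision when listing is disabled and only falls back to the /v1/models listing otherwise. A condensed, standalone restatement of that flow (a sketch, not the adapter code; the listing call is passed in as a stub):

# Condensed sketch of the availability decision -- not the adapter implementation.
async def is_model_available(model: str, allow_listing: bool, list_served_models) -> bool:
    if not allow_listing:
        # No HTTP request is made; the model is trusted to exist on the server.
        return True
    try:
        served = await list_served_models()  # stand-in for the /v1/models call
    except ConnectionError:
        return False
    return model in served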

View file

@@ -100,6 +100,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     async def should_refresh_models(self) -> bool:
         return False
 
+    async def allow_listing_models(self) -> bool:
+        return True
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)

View file

@@ -425,3 +425,6 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
     async def should_refresh_models(self) -> bool:
         return False
+
+    async def allow_listing_models(self) -> bool:
+        return True
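With both ModelRegistryHelper and OpenAIMixin defaulting allow_listing_models() to True, a provider opts out by overriding the coroutine, which is what the vLLM adapter does from its config flag. A self-contained sketch of that default-plus-override pattern (class names are illustrative, not from this commit):

# Illustrative classes only -- they mirror the pattern, not the real base classes.
class BaseAdapter:
    async def allow_listing_models(self) -> bool:
        return True  # same default as ModelRegistryHelper / OpenAIMixin

class LockedDownAdapter(BaseAdapter):
    async def allow_listing_models(self) -> bool:
        return False  # this provider will not be asked to list models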