Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-14 08:02:37 +00:00)
feat: Add allow_listing_models
• Add allow_listing_models configuration flag to VLLM provider to control model listing behavior
• Implement allow_listing_models() method across all providers with default implementations in base classes
• Prevent HTTP requests to /v1/models endpoint when allow_listing_models=false for improved security and performance
• Fix unit tests to include allow_listing_models method in test classes and mock objects
Parent: 188a56af5c
Commit: e9214f9004
15 changed files with 143 additions and 25 deletions
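Before the per-file hunks, here is a minimal self-contained sketch of the guard this commit describes. Every name in it (FakeInferenceProvider, http_calls) is illustrative rather than taken from the codebase; the point is only that when allow_listing_models is false, listing short-circuits before any /v1/models request would be issued.

import asyncio

# Illustrative stand-in for an inference provider; not a real llama-stack class.
class FakeInferenceProvider:
    def __init__(self, allow_listing_models: bool = True):
        self._allow_listing_models = allow_listing_models
        self.http_calls = 0  # counts simulated /v1/models requests

    async def allow_listing_models(self) -> bool:
        return self._allow_listing_models

    async def list_models(self) -> list[str] | None:
        if not await self.allow_listing_models():
            return None  # skip the /v1/models round trip entirely
        self.http_calls += 1  # stands in for client.models.list()
        return ["meta-llama/Llama-3.1-8B-Instruct"]

async def main() -> None:
    closed = FakeInferenceProvider(allow_listing_models=False)
    assert await closed.list_models() is None and closed.http_calls == 0

    opened = FakeInferenceProvider(allow_listing_models=True)
    assert await opened.list_models() and opened.http_calls == 1

asyncio.run(main())

The hunks below wire this same pattern through the models routing table, the provider base classes, and the vLLM adapter and its config.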
@@ -43,6 +43,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         await self.update_registered_models(provider_id, models)

     async def list_models(self) -> ListModelsResponse:
+        # Check if providers allow listing models before returning models
+        for provider_id, provider in self.impls_by_provider_id.items():
+            allow_listing_models = await provider.allow_listing_models()
+            logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}")
+            if not allow_listing_models:
+                logger.debug(f"Provider {provider_id} has allow_listing_models disabled")
         return ListModelsResponse(data=await self.get_all_with_type("model"))

     async def openai_list_models(self) -> OpenAIListModelsResponse:
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
@@ -16,6 +16,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
@@ -71,6 +71,9 @@ class MetaReferenceInferenceImpl(
     async def should_refresh_models(self) -> bool:
         return False

+    async def allow_listing_models(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return None

@@ -52,6 +52,9 @@ class SentenceTransformersInferenceImpl(
     async def should_refresh_models(self) -> bool:
         return False

+    async def allow_listing_models(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return [
             Model(
@@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=False,
         description="Whether to refresh models periodically",
     )
+    allow_listing_models: bool = Field(
+        default=True,
+        description="Whether to allow listing models from the vLLM server",
+    )

     @field_validator("tls_verify")
     @classmethod
@@ -59,4 +63,5 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
+            "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}",
         }
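As a rough illustration of how the new field behaves, the sketch below uses a stand-in pydantic model that mirrors only the fields visible in the hunk above; it is not the real VLLMInferenceAdapterConfig, and the remaining fields and import path are deliberately omitted.

from pydantic import BaseModel, Field

# Stand-in that mirrors the relevant fields of the vLLM adapter config above.
class VLLMConfigSketch(BaseModel):
    url: str | None = None
    refresh_models: bool = Field(default=False)
    allow_listing_models: bool = Field(
        default=True,  # listing stays enabled unless explicitly turned off
        description="Whether to allow listing models from the vLLM server",
    )

# The default keeps existing behaviour; setting the flag to False disables listing.
assert VLLMConfigSketch().allow_listing_models is True
assert VLLMConfigSketch(allow_listing_models=False).allow_listing_models is False

The ${env.VLLM_ALLOW_LISTING_MODELS:=true} substitution in the run configs above resolves to the same default, so existing deployments keep model listing enabled unless the variable is explicitly set to false.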
@@ -282,7 +282,18 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
         # Strictly respecting the refresh_models directive
         return self.config.refresh_models

+    async def allow_listing_models(self) -> bool:
+        # Respecting the allow_listing_models directive
+        result = self.config.allow_listing_models
+        log.debug(f"VLLM allow_listing_models: {result}")
+        return result
+
     async def list_models(self) -> list[Model] | None:
+        log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
+        if not self.config.allow_listing_models:
+            log.debug("VLLM list_models returning None due to allow_listing_models=False")
+            return None
+
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -332,24 +343,34 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

-    async def register_model(self, model: Model) -> Model:
-        try:
-            model = await self.register_helper.register_model(model)
-        except ValueError:
-            pass  # Ignore statically unknown model, will check live listing
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the vLLM server.
+
+        This method respects the allow_listing_models configuration flag.
+        If allow_listing_models is False, it returns True to allow model registration
+        without making HTTP requests (trusting that the model exists).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available or if allow_listing_models is False, False otherwise.
+        """
+        # Check if provider allows listing models before making HTTP request
+        if not self.config.allow_listing_models:
+            log.debug(
+                "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
+            )
+            return True
+
         try:
             res = self.client.models.list()
         except APIConnectionError as e:
-            raise ValueError(
-                f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
-            ) from e
+            log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
+            return False
+
         available_models = [m.id async for m in res]
-        if model.provider_resource_id not in available_models:
-            raise ValueError(
-                f"Model {model.provider_resource_id} is not being served by vLLM. "
-                f"Available models: {', '.join(available_models)}"
-            )
-        return model
+        is_available = model in available_models
+        log.debug(f"VLLM model {model} availability: {is_available}")
+        return is_available

     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         options = get_sampling_options(request.sampling_params)
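The docstring above reduces to a small decision rule. The sketch below restates it as a plain function (not the adapter itself) so the trust-the-model-exists behaviour with listing disabled is easy to see; the function name and arguments are illustrative.

# Distilled sketch of the availability decision: with listing disallowed,
# trust that the model exists and skip the /v1/models request; otherwise
# check membership in the ids the server reports.
def availability(allow_listing_models: bool, model: str, served_ids: list[str]) -> bool:
    if not allow_listing_models:
        return True  # no HTTP call; registration is allowed to proceed
    return model in served_ids

assert availability(False, "my-model", []) is True
assert availability(True, "my-model", ["my-model"]) is True
assert availability(True, "my-model", ["other-model"]) is False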
@@ -100,6 +100,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     async def should_refresh_models(self) -> bool:
         return False

+    async def allow_listing_models(self) -> bool:
+        return True
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)

@@ -425,3 +425,6 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):

     async def should_refresh_models(self) -> bool:
         return False
+
+    async def allow_listing_models(self) -> bool:
+        return True
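The unit-test updates mentioned in the commit message are not among the hunks shown here, but the last bullet implies that any test double standing in for a provider now needs an awaitable allow_listing_models. A possible sketch of such a double with unittest.mock.AsyncMock:

import asyncio
from unittest.mock import AsyncMock

# Hypothetical provider double for routing-table style tests; the attribute
# names mirror the methods added in this commit.
provider = AsyncMock()
provider.should_refresh_models = AsyncMock(return_value=False)
provider.allow_listing_models = AsyncMock(return_value=False)
provider.list_models = AsyncMock(return_value=None)

async def exercise() -> None:
    # Awaiting the mocked methods behaves like awaiting a real provider impl.
    assert await provider.allow_listing_models() is False
    assert await provider.list_models() is None

asyncio.run(exercise())

Because AsyncMock attributes are awaitable, code that awaits provider.allow_listing_models() runs unchanged against such a double.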