Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
fix: enforce allowed_models during inference requests (#4197)
The `allowed_models` configuration was only applied when listing models via the `/v1/models` endpoint; the actual inference requests never checked the restriction. Users could therefore request any model the provider supports by naming it directly in an inference call, completely bypassing the intended cost controls.

The fix adds validation to all three inference methods (chat completions, completions, and embeddings) that checks the requested model against the `allowed_models` list before making the provider API call.

### Test plan

Added unit tests.
This commit is contained in:
parent: b6ce242808
commit: d649c3663e

2 changed files with 126 additions and 4 deletions
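To make the bypass concrete before reading the diff, here is a hedged sketch (not from this commit) of the behavior being closed, using an OpenAI-compatible client pointed at a Llama Stack server. The base URL, API key, and model IDs are illustrative placeholders, and the exact endpoint path depends on your deployment. The diff below shows the changes to `OpenAIMixin`; the second changed file (the unit tests) did not render in this mirror view.

```python
# Illustration of the gap this commit closes; base_url, api_key, and
# model ids are placeholders, not values taken from the commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Model listing already respected allowed_models, so a restricted
# model does not appear here:
print([m.id for m in client.models.list()])

# But before this fix, naming a restricted model directly in an
# inference request still reached the provider, since none of the
# inference methods checked allowed_models:
client.chat.completions.create(
    model="restricted-model",  # placeholder id not in allowed_models
    messages=[{"role": "user", "content": "hello"}],
)
# After this fix, the request is rejected before the provider call,
# with: "Model 'restricted-model' is not in the allowed models list. ..."
```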
```diff
@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
```
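A minimal standalone sketch of the check `_validate_model_allowed` performs, runnable outside the mixin (the `Config` stand-in and model IDs are hypothetical). Note the semantics the condition implies: `None` means unrestricted, while an empty list would reject every model.

```python
# Self-contained re-creation of the check above; Config is a stand-in
# for the provider config, with the same allowed_models field.
from dataclasses import dataclass

@dataclass
class Config:
    allowed_models: list[str] | None = None  # None = no restriction

def validate_model_allowed(config: Config, provider_model_id: str) -> None:
    # Mirrors the mixin's logic: only enforce when a list is configured.
    if config.allowed_models is not None and provider_model_id not in config.allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {config.allowed_models}"
        )

validate_model_allowed(Config(), "anything")                           # passes: unrestricted
validate_model_allowed(Config(allowed_models=["model-a"]), "model-a")  # passes: listed
validate_model_allowed(Config(allowed_models=["model-a"]), "model-b")  # raises ValueError
```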
```diff
@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:
@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:
```
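The test plan above mentions added unit tests; they live in the second changed file, which did not render here. A self-contained sketch of the behavior such tests would cover, with all class and test names hypothetical, is:

```python
# Hypothetical pytest sketch (not the commit's actual tests) showing the
# enforced ordering: validation happens before any provider API call.
import pytest

class FakeProvider:
    """Stand-in mirroring the mixin's new call-site pattern."""

    def __init__(self, allowed_models: list[str] | None):
        self.allowed_models = allowed_models
        self.provider_called = False

    def _validate_model_allowed(self, provider_model_id: str) -> None:
        if self.allowed_models is not None and provider_model_id not in self.allowed_models:
            raise ValueError(f"Model '{provider_model_id}' is not in the allowed models list.")

    def chat_completion(self, model: str) -> str:
        self._validate_model_allowed(model)
        self.provider_called = True  # where the real provider request would go
        return "ok"

def test_disallowed_model_rejected_before_provider_call():
    provider = FakeProvider(allowed_models=["permitted-model"])
    with pytest.raises(ValueError, match="not in the allowed models list"):
        provider.chat_completion("forbidden-model")
    assert not provider.provider_called  # no provider request was made

def test_allowed_model_passes_through():
    provider = FakeProvider(allowed_models=["permitted-model"])
    assert provider.chat_completion("permitted-model") == "ok"
```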