fix: enforce allowed_models during inference requests (backport #4197) (#4228)

The `allowed_models` configuration was only being applied when listing
models via the `/v1/models` endpoint, but the actual inference requests
weren't checking this restriction. This meant users could directly
request any model the provider supports by specifying it in their
inference call, completely bypassing the intended cost controls.

The fix adds validation to all three inference methods (chat
completions, completions, and embeddings) that checks the requested
model against the allowed_models list before making the provider API
call.

### Test plan

Added unit tests <hr>This is an automatic backport of pull request #4197
done by [Mergify](https://mergify.com).

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
mergify[bot] 2025-11-24 11:31:36 -08:00 committed by GitHub
parent 0df6d4601f
commit 05b4394cf9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 128 additions and 6 deletions

View file

@ -188,6 +188,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
return api_key return api_key
def _validate_model_allowed(self, provider_model_id: str) -> None:
"""
Validate that the model is in the allowed_models list if configured.
:param provider_model_id: The provider-specific model ID to validate
:raises ValueError: If the model is not in the allowed_models list
"""
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
raise ValueError(
f"Model '{provider_model_id}' is not in the allowed models list. "
f"Allowed models: {self.config.allowed_models}"
)
async def _get_provider_model_id(self, model: str) -> str: async def _get_provider_model_id(self, model: str) -> str:
""" """
Get the provider-specific model ID from the model store. Get the provider-specific model ID from the model store.
@ -234,8 +247,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
Direct OpenAI completion API call. Direct OpenAI completion API call.
""" """
# TODO: fix openai_completion to return type compatible with OpenAI's API response # TODO: fix openai_completion to return type compatible with OpenAI's API response
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)
completion_kwargs = await prepare_openai_completion_params( completion_kwargs = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model), model=provider_model_id,
prompt=params.prompt, prompt=params.prompt,
best_of=params.best_of, best_of=params.best_of,
echo=params.echo, echo=params.echo,
@ -267,6 +283,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
""" """
Direct OpenAI chat completion API call. Direct OpenAI chat completion API call.
""" """
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)
messages = params.messages messages = params.messages
if self.download_images: if self.download_images:
@ -288,7 +307,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
messages = [await _localize_image_url(m) for m in messages] messages = [await _localize_image_url(m) for m in messages]
request_params = await prepare_openai_completion_params( request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model), model=provider_model_id,
messages=messages, messages=messages,
frequency_penalty=params.frequency_penalty, frequency_penalty=params.frequency_penalty,
function_call=params.function_call, function_call=params.function_call,
@ -326,9 +345,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
""" """
Direct OpenAI embeddings API call. Direct OpenAI embeddings API call.
""" """
# Prepare request parameters provider_model_id = await self._get_provider_model_id(params.model)
request_params = { self._validate_model_allowed(provider_model_id)
"model": await self._get_provider_model_id(params.model),
# Build request params conditionally to avoid NotGiven/Omit type mismatch
# The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
request_params: dict[str, Any] = {
"model": provider_model_id,
"input": params.input, "input": params.input,
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN, "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN, "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,

View file

@ -12,7 +12,13 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam from llama_stack.apis.inference import (
Model,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIUserMessageParam,
)
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -743,3 +749,96 @@ class TestOpenAIMixinProviderDataApiKey:
error_message = str(exc_info.value) error_message = str(exc_info.value)
assert "test_api_key" in error_message assert "test_api_key" in error_message
assert "x-llamastack-provider-data" in error_message assert "x-llamastack-provider-data" in error_message
class TestOpenAIMixinAllowedModelsInference:
"""Test cases for allowed_models enforcement during inference requests"""
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
"""Test that all inference methods succeed with allowed models"""
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
mock_client.completions.create = AsyncMock(return_value=MagicMock())
mock_embedding_response = MagicMock()
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
with mock_client_context(mixin, mock_client):
# Test chat completion
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()
# Test completion
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
)
mock_client.completions.create.assert_called_once()
# Test embeddings
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
)
mock_client.embeddings.create.assert_called_once()
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
"""Test that all inference methods fail with disallowed models"""
mixin.config.allowed_models = ["gpt-4"]
mock_client = MagicMock()
with mock_client_context(mixin, mock_client):
# Test chat completion with disallowed model
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
# Test completion with disallowed model
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
)
# Test embeddings with disallowed model
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
)
mock_client.chat.completions.create.assert_not_called()
mock_client.completions.create.assert_not_called()
mock_client.embeddings.create.assert_not_called()
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
# Test with None (no restrictions)
assert mixin.config.allowed_models is None
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
with mock_client_context(mixin, mock_client):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()
# Test with empty list (blocks all models)
mixin.config.allowed_models = []
with mock_client_context(mixin, mock_client):
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)