fix: enforce allowed_models during inference requests

The `allowed_models` configuration only filtered the model list endpoint; it was not enforced during actual inference requests. Users could therefore bypass the restriction by directly requesting a model outside the allowed list, potentially reaching expensive models when only cheaper ones were intended.
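
To make the gap concrete, here is a deliberately toy sketch (class and method names are invented for illustration; the real code is in the diff below): the allowlist was consulted when listing models but never on the inference path.

```python
# Toy illustration of the pre-fix behavior; not the actual provider code.
class ToyProvider:
    def __init__(self, available: list[str], allowed_models: list[str] | None):
        self.available = available
        self.allowed_models = allowed_models

    def list_models(self) -> list[str]:
        # The allowlist filtered what users could see...
        if self.allowed_models is None:
            return self.available
        return [m for m in self.available if m in self.allowed_models]

    def chat(self, model: str) -> str:
        # ...but nothing stopped a direct request for a hidden model.
        return f"completion from {model}"


provider = ToyProvider(["cheap-model", "expensive-model"], allowed_models=["cheap-model"])
print(provider.list_models())            # ['cheap-model']
print(provider.chat("expensive-model"))  # still succeeds: the bypass this commit closes
```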

This change adds validation to all inference methods (`openai_chat_completion`, `openai_completion`, `openai_embeddings`) to reject requests for disallowed models with a clear error message.

**Implementation:**
- Added `_validate_model_allowed()` helper method that checks if a model is in the `allowed_models` list
- Called validation in all three inference methods before making API requests
- Validation occurs after resolving the provider model ID to ensure consistency; the resolve-then-validate pattern is sketched below
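
The ordering matters because a registered alias may map to a different provider-side ID; validating the unresolved name would make the check depend on which alias a caller used. A condensed, self-contained sketch of the pattern each inference method now follows (the alias table and API call are stand-in stubs, not the real model store):

```python
import asyncio


class ResolveThenValidateSketch:
    """Self-contained sketch of the pattern from the diff; stubs replace real logic."""

    def __init__(self, allowed_models: list[str] | None):
        self.allowed_models = allowed_models

    async def _get_provider_model_id(self, model: str) -> str:
        # Stub for alias resolution via the model store.
        return {"my-gpt4-alias": "gpt-4"}.get(model, model)

    def _validate_model_allowed(self, provider_model_id: str) -> None:
        # Mirrors the check added in the diff below.
        if self.allowed_models is not None and provider_model_id not in self.allowed_models:
            raise ValueError(f"Model '{provider_model_id}' is not in the allowed models list.")

    async def openai_chat_completion(self, model: str) -> str:
        provider_model_id = await self._get_provider_model_id(model)  # resolve first
        self._validate_model_allowed(provider_model_id)               # then enforce
        return f"provider call with {provider_model_id}"              # stub API call


sketch = ResolveThenValidateSketch(allowed_models=["gpt-4"])
print(asyncio.run(sketch.openai_chat_completion("my-gpt4-alias")))  # allowed: resolves to gpt-4
```

Because the check sees the resolved ID, an alias for an allowed model passes while an alias for a disallowed one is rejected, regardless of the user-facing name.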

**Test Plan:**
- Added unit tests verifying all inference methods respect `allowed_models`
- Tests cover allowed models (success), disallowed models (rejection), and the no-restriction cases: `None` allows all models, while an empty list blocks all (summarized in the sketch below)
- All existing tests continue to pass
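
For quick reference, the three-valued semantics those tests pin down (a standalone sketch; `check_model_allowed` is an illustrative name, the real check is `_validate_model_allowed` in the diff below):

```python
def check_model_allowed(provider_model_id: str, allowed_models: list[str] | None) -> None:
    # None disables the check entirely; an empty list allows nothing;
    # any other list is a strict allowlist. Mirrors _validate_model_allowed below.
    if allowed_models is not None and provider_model_id not in allowed_models:
        raise ValueError(f"Model '{provider_model_id}' is not in the allowed models list.")


check_model_allowed("any-model", None)   # OK: None means no restriction
check_model_allowed("gpt-4", ["gpt-4"])  # OK: explicitly allowed
for bad_case in (("gpt-4-turbo", ["gpt-4"]), ("gpt-4", [])):
    try:
        check_model_allowed(*bad_case)
    except ValueError as e:
        print(e)  # rejected: not in the list / empty list blocks all
```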

Fixes GHSA-5rjj-4jp6-fw39
Ashwin Bharambe 2025-11-19 12:12:28 -08:00
parent 8852666982
commit db6488b379
2 changed files with 126 additions and 4 deletions

**llama_stack/providers/utils/inference/openai_mixin.py**

```diff
@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:
@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:
```

**tests/unit/providers/utils/inference/test_openai_mixin.py**

```diff
@@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 
 class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinAllowedModelsInference:
+    """Test cases for allowed_models enforcement during inference requests"""
+
+    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods succeed with allowed models"""
+        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+        mock_embedding_response = MagicMock()
+        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
+        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
+        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+            # Test completion
+            await mixin.openai_completion(
+                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
+            )
+            mock_client.completions.create.assert_called_once()
+
+            # Test embeddings
+            await mixin.openai_embeddings(
+                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
+            )
+            mock_client.embeddings.create.assert_called_once()
+
+    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods fail with disallowed models"""
+        mixin.config.allowed_models = ["gpt-4"]
+
+        mock_client = MagicMock()
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+            mock_client.chat.completions.create.assert_not_called()
+            mock_client.completions.create.assert_not_called()
+            mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that inference succeeds when allowed_models is None or empty list blocks all"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
```