mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix: enforce allowed_models during inference requests
The `allowed_models` configuration was only filtering the model list endpoint but not enforcing restrictions during actual inference requests. This allowed users to bypass the restriction by directly requesting models not in the allowed list, potentially accessing expensive models when only cheaper ones were intended. This change adds validation to all inference methods (`openai_chat_completion`, `openai_completion`, `openai_embeddings`) to reject requests for disallowed models with a clear error message. **Implementation:** - Added `_validate_model_allowed()` helper method that checks if a model is in the `allowed_models` list - Called validation in all three inference methods before making API requests - Validation occurs after resolving the provider model ID to ensure consistency **Test Plan:** - Added unit tests verifying all inference methods respect `allowed_models` - Tests cover allowed models (success), disallowed models (rejection), and no restrictions (None allows all, empty list blocks all) - All existing tests continue to pass Fixes GHSA-5rjj-4jp6-fw39
This commit is contained in:
parent
8852666982
commit
db6488b379
2 changed files with 126 additions and 4 deletions
|
|
@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
|
||||
return api_key
|
||||
|
||||
def _validate_model_allowed(self, provider_model_id: str) -> None:
|
||||
"""
|
||||
Validate that the model is in the allowed_models list if configured.
|
||||
|
||||
:param provider_model_id: The provider-specific model ID to validate
|
||||
:raises ValueError: If the model is not in the allowed_models list
|
||||
"""
|
||||
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
|
||||
raise ValueError(
|
||||
f"Model '{provider_model_id}' is not in the allowed models list. "
|
||||
f"Allowed models: {self.config.allowed_models}"
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model: str) -> str:
|
||||
"""
|
||||
Get the provider-specific model ID from the model store.
|
||||
|
|
@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
Direct OpenAI completion API call.
|
||||
"""
|
||||
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
||||
provider_model_id = await self._get_provider_model_id(params.model)
|
||||
self._validate_model_allowed(provider_model_id)
|
||||
|
||||
completion_kwargs = await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(params.model),
|
||||
model=provider_model_id,
|
||||
prompt=params.prompt,
|
||||
best_of=params.best_of,
|
||||
echo=params.echo,
|
||||
|
|
@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
"""
|
||||
Direct OpenAI chat completion API call.
|
||||
"""
|
||||
provider_model_id = await self._get_provider_model_id(params.model)
|
||||
self._validate_model_allowed(provider_model_id)
|
||||
|
||||
messages = params.messages
|
||||
|
||||
if self.download_images:
|
||||
|
|
@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
messages = [await _localize_image_url(m) for m in messages]
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(params.model),
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=params.frequency_penalty,
|
||||
function_call=params.function_call,
|
||||
|
|
@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
"""
|
||||
Direct OpenAI embeddings API call.
|
||||
"""
|
||||
provider_model_id = await self._get_provider_model_id(params.model)
|
||||
self._validate_model_allowed(provider_model_id)
|
||||
|
||||
# Build request params conditionally to avoid NotGiven/Omit type mismatch
|
||||
# The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
|
||||
request_params: dict[str, Any] = {
|
||||
"model": await self._get_provider_model_id(params.model),
|
||||
"model": provider_model_id,
|
||||
"input": params.input,
|
||||
}
|
||||
if params.encoding_format is not None:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack_api import (
|
||||
Model,
|
||||
ModelType,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
|
|
@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
|
|||
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModelsInference:
|
||||
"""Test cases for allowed_models enforcement during inference requests"""
|
||||
|
||||
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods succeed with allowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_client.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_embedding_response = MagicMock()
|
||||
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
|
||||
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test completion
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
|
||||
)
|
||||
mock_client.completions.create.assert_called_once()
|
||||
|
||||
# Test embeddings
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
|
||||
)
|
||||
mock_client.embeddings.create.assert_called_once()
|
||||
|
||||
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods fail with disallowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
|
||||
# Test completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
|
||||
)
|
||||
|
||||
# Test embeddings with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
|
||||
)
|
||||
|
||||
mock_client.chat.completions.create.assert_not_called()
|
||||
mock_client.completions.create.assert_not_called()
|
||||
mock_client.embeddings.create.assert_not_called()
|
||||
|
||||
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
|
||||
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
|
||||
# Test with None (no restrictions)
|
||||
assert mixin.config.allowed_models is None
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test with empty list (blocks all models)
|
||||
mixin.config.allowed_models = []
|
||||
with mock_client_context(mixin, mock_client):
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue