diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 0a283780f..e7ec13671 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -191,6 +191,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
 
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
@@ -237,8 +250,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -270,6 +286,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:
@@ -291,7 +310,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -329,9 +348,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
""" - # Prepare request parameters - request_params = { - "model": await self._get_provider_model_id(params.model), + provider_model_id = await self._get_provider_model_id(params.model) + self._validate_model_allowed(provider_model_id) + + # Build request params conditionally to avoid NotGiven/Omit type mismatch + # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven + request_params: dict[str, Any] = { + "model": provider_model_id, "input": params.input, "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN, "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN, diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 61a1f8f61..37b852709 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -12,7 +12,13 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch import pytest from pydantic import BaseModel, Field -from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam +from llama_stack.apis.inference import ( + Model, + OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIUserMessageParam, +) from llama_stack.apis.models import ModelType from llama_stack.core.request_headers import request_provider_data_context from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig @@ -733,3 +739,96 @@ class TestOpenAIMixinProviderDataApiKey: error_message = str(exc_info.value) assert "test_api_key" in error_message assert "x-llamastack-provider-data" in error_message + + +class TestOpenAIMixinAllowedModelsInference: + """Test cases for allowed_models enforcement during inference requests""" + + async def test_inference_with_allowed_models(self, mixin, mock_client_context): + """Test that all inference methods succeed with allowed models""" + mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + mock_client.completions.create = AsyncMock(return_value=MagicMock()) + mock_embedding_response = MagicMock() + mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5) + mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response) + + with mock_client_context(mixin, mock_client): + # Test chat completion + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + mock_client.chat.completions.create.assert_called_once() + + # Test completion + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello") + ) + mock_client.completions.create.assert_called_once() + + # Test embeddings + await mixin.openai_embeddings( + OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text") + ) + mock_client.embeddings.create.assert_called_once() + + async def test_inference_with_disallowed_models(self, mixin, mock_client_context): + """Test that all inference methods fail with disallowed models""" + mixin.config.allowed_models = ["gpt-4"] + + mock_client = 
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+        mock_client.chat.completions.create.assert_not_called()
+        mock_client.completions.create.assert_not_called()
+        mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that allowed_models=None imposes no restriction and an empty list blocks all models"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
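
For reference, the allowed_models gate added above reduces to the following standalone sketch. The free function validate_model_allowed here is hypothetical and for illustration only; in the patch the check is the OpenAIMixin._validate_model_allowed method reading self.config.allowed_models before each completion, chat completion, and embeddings call.

    # Minimal sketch of the allowed_models gate (hypothetical free function;
    # the real check lives on OpenAIMixin and reads config.allowed_models).
    def validate_model_allowed(allowed_models: list[str] | None, provider_model_id: str) -> None:
        # None means "no restriction"; an empty list blocks every model.
        if allowed_models is not None and provider_model_id not in allowed_models:
            raise ValueError(
                f"Model '{provider_model_id}' is not in the allowed models list. "
                f"Allowed models: {allowed_models}"
            )

    validate_model_allowed(None, "gpt-4")       # passes: no restriction configured
    validate_model_allowed(["gpt-4"], "gpt-4")  # passes: explicitly allowed
    try:
        validate_model_allowed([], "gpt-4")     # raises: empty list blocks all models
    except ValueError as e:
        print(e)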