From 2be6ec2ad86166a20ab66bbf590483913bb62814 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 1 Oct 2025 14:14:24 -0400 Subject: [PATCH] chore: add provider-data-api-key support to openaimixin --- .../utils/inference/litellm_openai_mixin.py | 4 +- .../utils/inference/model_registry.py | 2 +- .../providers/utils/inference/openai_mixin.py | 27 +++++- .../inference/test_openai_base_url_config.py | 7 ++ .../utils/inference/test_openai_mixin.py | 93 ++++++++++++++++--- 5 files changed, 116 insertions(+), 17 deletions(-) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 23a72bb3a..c8d3bddc7 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -59,7 +59,7 @@ class LiteLLMOpenAIMixin( self, litellm_provider_name: str, api_key_from_config: str | None, - provider_data_api_key_field: str, + provider_data_api_key_field: str | None = None, model_entries: list[ProviderModelEntry] | None = None, openai_compat_api_base: str | None = None, download_images: bool = False, @@ -70,7 +70,7 @@ class LiteLLMOpenAIMixin( :param model_entries: The model entries to register. :param api_key_from_config: The API key to use from the config. - :param provider_data_api_key_field: The field in the provider data that contains the API key. + :param provider_data_api_key_field: The field in the provider data that contains the API key (optional). :param litellm_provider_name: The name of the provider, used for model lookups. :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility. :param download_images: Whether to download images and convert to base64 for message conversion. diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 746ebd8f6..4913c2e1f 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -63,7 +63,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): model_entries: list[ProviderModelEntry] | None = None, allowed_models: list[str] | None = None, ): - self.allowed_models = allowed_models + self.allowed_models = allowed_models if allowed_models else [] self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 7da97e6b1..becec5fb3 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( OpenAIResponseFormatParam, ) from llama_stack.apis.models import ModelType +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params @@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import localize_image_ logger = get_logger(name=__name__, category="providers::utils") -class OpenAIMixin(ModelRegistryHelper, ABC): +class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC): """ Mixin class that provides OpenAI-specific functionality for inference providers. This class handles direct OpenAI API calls using the AsyncOpenAI client. @@ -69,6 +70,9 @@ class OpenAIMixin(ModelRegistryHelper, ABC): # List of allowed models for this provider, if empty all models allowed allowed_models: list[str] = [] + # Optional field name in provider data to look for API key, which takes precedence + provider_data_api_key_field: str | None = None + @abstractmethod def get_api_key(self) -> str: """ @@ -111,9 +115,28 @@ class OpenAIMixin(ModelRegistryHelper, ABC): Uses the abstract methods get_api_key() and get_base_url() which must be implemented by child classes. + + Users can also provide the API key via the provider data header, which + is used instead of any config API key. """ + + api_key = self.get_api_key() + + if self.provider_data_api_key_field: + provider_data = self.get_request_provider_data() + if provider_data and getattr(provider_data, self.provider_data_api_key_field, None): + api_key = getattr(provider_data, self.provider_data_api_key_field) + + if not api_key: # TODO: let get_api_key return None + raise ValueError( + "API key is not set. Please provide a valid API key in the " + "provider data header, e.g. x-llamastack-provider-data: " + f'{{"{self.provider_data_api_key_field}": ""}}, ' + "or in the provider config." + ) + return AsyncOpenAI( - api_key=self.get_api_key(), + api_key=api_key, base_url=self.get_base_url(), **self.get_extra_client_params(), ) diff --git a/tests/unit/providers/inference/test_openai_base_url_config.py b/tests/unit/providers/inference/test_openai_base_url_config.py index 903772f0c..7c5a5b327 100644 --- a/tests/unit/providers/inference/test_openai_base_url_config.py +++ b/tests/unit/providers/inference/test_openai_base_url_config.py @@ -19,6 +19,7 @@ class TestOpenAIBaseURLConfig: """Test that the adapter uses the default OpenAI base URL when no environment variable is set.""" config = OpenAIConfig(api_key="test-key") adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test assert adapter.get_base_url() == "https://api.openai.com/v1" @@ -27,6 +28,7 @@ class TestOpenAIBaseURLConfig: custom_url = "https://custom.openai.com/v1" config = OpenAIConfig(api_key="test-key", base_url=custom_url) adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test assert adapter.get_base_url() == custom_url @@ -38,6 +40,7 @@ class TestOpenAIBaseURLConfig: processed_config = replace_env_vars(config_data) config = OpenAIConfig.model_validate(processed_config) adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test assert adapter.get_base_url() == "https://env.openai.com/v1" @@ -47,6 +50,7 @@ class TestOpenAIBaseURLConfig: custom_url = "https://config.openai.com/v1" config = OpenAIConfig(api_key="test-key", base_url=custom_url) adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test # Config should take precedence over environment variable assert adapter.get_base_url() == custom_url @@ -57,6 +61,7 @@ class TestOpenAIBaseURLConfig: custom_url = "https://test.openai.com/v1" config = OpenAIConfig(api_key="test-key", base_url=custom_url) adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin adapter.get_api_key = MagicMock(return_value="test-key") @@ -76,6 +81,7 @@ class TestOpenAIBaseURLConfig: custom_url = "https://test.openai.com/v1" config = OpenAIConfig(api_key="test-key", base_url=custom_url) adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test # Mock the get_api_key method adapter.get_api_key = MagicMock(return_value="test-key") @@ -117,6 +123,7 @@ class TestOpenAIBaseURLConfig: processed_config = replace_env_vars(config_data) config = OpenAIConfig.model_validate(processed_config) adapter = OpenAIInferenceAdapter(config) + adapter.provider_data_api_key_field = None # Disable provider data for this test # Mock the get_api_key method adapter.get_api_key = MagicMock(return_value="test-key") diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index b55f206b9..8ef7ec81c 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -4,18 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch +import json +from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch import pytest +from pydantic import BaseModel, Field from llama_stack.apis.inference import Model, OpenAIUserMessageParam from llama_stack.apis.models import ModelType +from llama_stack.core.request_headers import request_provider_data_context from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin class OpenAIMixinImpl(OpenAIMixin): - def __init__(self): - self.__provider_id__ = "test-provider" + __provider_id__: str = "test-provider" def get_api_key(self) -> str: raise NotImplementedError("This method should be mocked in tests") @@ -24,7 +26,7 @@ class OpenAIMixinImpl(OpenAIMixin): raise NotImplementedError("This method should be mocked in tests") -class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin): +class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): """Test implementation with embedding model metadata""" embedding_model_metadata = { @@ -32,14 +34,6 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin): "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, } - __provider_id__ = "test-provider" - - def get_api_key(self) -> str: - raise NotImplementedError("This method should be mocked in tests") - - def get_base_url(self) -> str: - raise NotImplementedError("This method should be mocked in tests") - @pytest.fixture def mixin(): @@ -366,3 +360,78 @@ class TestOpenAIMixinAllowedModels: assert await mixin.check_model_availability("final-mock-model-id") assert not await mixin.check_model_availability("some-mock-model-id") assert not await mixin.check_model_availability("another-mock-model-id") + + +class ProviderDataValidator(BaseModel): + """Validator for provider data in tests""" + + test_api_key: str | None = Field(default=None) + + +class OpenAIMixinWithProviderData(OpenAIMixinImpl): + """Test implementation that supports provider data API key field""" + + provider_data_api_key_field: str = "test_api_key" + + def get_api_key(self) -> str: + return "default-api-key" + + def get_base_url(self): + return "default-base-url" + + +class TestOpenAIMixinProviderDataApiKey: + """Test cases for provider_data_api_key_field functionality""" + + @pytest.fixture + def mixin_with_provider_data_field(self): + """Mixin instance with provider_data_api_key_field set""" + mixin_instance = OpenAIMixinWithProviderData() + + # Mock provider_spec for provider data validation + mock_provider_spec = MagicMock() + mock_provider_spec.provider_type = "test-provider-with-data" + mock_provider_spec.provider_data_validator = ( + "tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator" + ) + mixin_instance.__provider_spec__ = mock_provider_spec + + return mixin_instance + + @pytest.fixture + def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field): + mixin_with_provider_data_field.get_api_key = Mock(return_value=None) + return mixin_with_provider_data_field + + def test_no_provider_data(self, mixin_with_provider_data_field): + """Test that client uses config API key when no provider data is available""" + assert mixin_with_provider_data_field.client.api_key == "default-api-key" + + def test_with_provider_data(self, mixin_with_provider_data_field): + """Test that provider data API key overrides config API key""" + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})} + ): + assert mixin_with_provider_data_field.client.api_key == "provider-data-key" + + def test_with_wrong_key(self, mixin_with_provider_data_field): + """Test fallback to config when provider data doesn't have the required key""" + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): + assert mixin_with_provider_data_field.client.api_key == "default-api-key" + + def test_error_when_no_config_and_provider_data_has_wrong_key( + self, mixin_with_provider_data_field_and_none_api_key + ): + """Test that ValueError is raised when provider data exists but doesn't have required key""" + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): + with pytest.raises(ValueError, match="API key is not set"): + _ = mixin_with_provider_data_field_and_none_api_key.client + + def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key): + """Test that error message includes correct field name and header information""" + with pytest.raises(ValueError) as exc_info: + _ = mixin_with_provider_data_field_and_none_api_key.client + + error_message = str(exc_info.value) + assert "test_api_key" in error_message + assert "x-llamastack-provider-data" in error_message