mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
chore: add provider-data-api-key support to openaimixin
This commit is contained in:
parent
f7c5ef4ec0
commit
2be6ec2ad8
5 changed files with 116 additions and 17 deletions
|
@ -59,7 +59,7 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
provider_data_api_key_field: str | None = None,
|
||||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
openai_compat_api_base: str | None = None,
|
||||
download_images: bool = False,
|
||||
|
@ -70,7 +70,7 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
:param model_entries: The model entries to register.
|
||||
:param api_key_from_config: The API key to use from the config.
|
||||
:param provider_data_api_key_field: The field in the provider data that contains the API key.
|
||||
:param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
|
||||
:param litellm_provider_name: The name of the provider, used for model lookups.
|
||||
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
|
||||
:param download_images: Whether to download images and convert to base64 for message conversion.
|
||||
|
|
|
@ -63,7 +63,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
model_entries: list[ProviderModelEntry] | None = None,
|
||||
allowed_models: list[str] | None = None,
|
||||
):
|
||||
self.allowed_models = allowed_models
|
||||
self.allowed_models = allowed_models if allowed_models else []
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
@ -32,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import localize_image_
|
|||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||
class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
|
||||
"""
|
||||
Mixin class that provides OpenAI-specific functionality for inference providers.
|
||||
This class handles direct OpenAI API calls using the AsyncOpenAI client.
|
||||
|
@ -69,6 +70,9 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -111,9 +115,28 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
|||
|
||||
Uses the abstract methods get_api_key() and get_base_url() which must be
|
||||
implemented by child classes.
|
||||
|
||||
Users can also provide the API key via the provider data header, which
|
||||
is used instead of any config API key.
|
||||
"""
|
||||
|
||||
api_key = self.get_api_key()
|
||||
|
||||
if self.provider_data_api_key_field:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
if not api_key: # TODO: let get_api_key return None
|
||||
raise ValueError(
|
||||
"API key is not set. Please provide a valid API key in the "
|
||||
"provider data header, e.g. x-llamastack-provider-data: "
|
||||
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
|
||||
"or in the provider config."
|
||||
)
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
api_key=api_key,
|
||||
base_url=self.get_base_url(),
|
||||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
|
|
@ -19,6 +19,7 @@ class TestOpenAIBaseURLConfig:
|
|||
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
|
||||
config = OpenAIConfig(api_key="test-key")
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == "https://api.openai.com/v1"
|
||||
|
||||
|
@ -27,6 +28,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://custom.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
||||
|
@ -38,6 +40,7 @@ class TestOpenAIBaseURLConfig:
|
|||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == "https://env.openai.com/v1"
|
||||
|
||||
|
@ -47,6 +50,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://config.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Config should take precedence over environment variable
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
@ -57,6 +61,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
@ -76,6 +81,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
@ -117,6 +123,7 @@ class TestOpenAIBaseURLConfig:
|
|||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
|
|
@ -4,18 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
def __init__(self):
|
||||
self.__provider_id__ = "test-provider"
|
||||
__provider_id__: str = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
@ -24,7 +26,7 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||
"""Test implementation with embedding model metadata"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
|
@ -32,14 +34,6 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
|||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
|
@ -366,3 +360,78 @@ class TestOpenAIMixinAllowedModels:
|
|||
assert await mixin.check_model_availability("final-mock-model-id")
|
||||
assert not await mixin.check_model_availability("some-mock-model-id")
|
||||
assert not await mixin.check_model_availability("another-mock-model-id")
|
||||
|
||||
|
||||
class ProviderDataValidator(BaseModel):
|
||||
"""Validator for provider data in tests"""
|
||||
|
||||
test_api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class OpenAIMixinWithProviderData(OpenAIMixinImpl):
|
||||
"""Test implementation that supports provider data API key field"""
|
||||
|
||||
provider_data_api_key_field: str = "test_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return "default-api-key"
|
||||
|
||||
def get_base_url(self):
|
||||
return "default-base-url"
|
||||
|
||||
|
||||
class TestOpenAIMixinProviderDataApiKey:
|
||||
"""Test cases for provider_data_api_key_field functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field(self):
|
||||
"""Mixin instance with provider_data_api_key_field set"""
|
||||
mixin_instance = OpenAIMixinWithProviderData()
|
||||
|
||||
# Mock provider_spec for provider data validation
|
||||
mock_provider_spec = MagicMock()
|
||||
mock_provider_spec.provider_type = "test-provider-with-data"
|
||||
mock_provider_spec.provider_data_validator = (
|
||||
"tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator"
|
||||
)
|
||||
mixin_instance.__provider_spec__ = mock_provider_spec
|
||||
|
||||
return mixin_instance
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field):
|
||||
mixin_with_provider_data_field.get_api_key = Mock(return_value=None)
|
||||
return mixin_with_provider_data_field
|
||||
|
||||
def test_no_provider_data(self, mixin_with_provider_data_field):
|
||||
"""Test that client uses config API key when no provider data is available"""
|
||||
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
|
||||
|
||||
def test_with_provider_data(self, mixin_with_provider_data_field):
|
||||
"""Test that provider data API key overrides config API key"""
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
|
||||
):
|
||||
assert mixin_with_provider_data_field.client.api_key == "provider-data-key"
|
||||
|
||||
def test_with_wrong_key(self, mixin_with_provider_data_field):
|
||||
"""Test fallback to config when provider data doesn't have the required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
|
||||
|
||||
def test_error_when_no_config_and_provider_data_has_wrong_key(
|
||||
self, mixin_with_provider_data_field_and_none_api_key
|
||||
):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
|
||||
"""Test that error message includes correct field name and header information"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue