mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
chore: add provider-data-api-key support to openaimixin
This commit is contained in:
parent
f7c5ef4ec0
commit
2be6ec2ad8
5 changed files with 116 additions and 17 deletions
|
@ -19,6 +19,7 @@ class TestOpenAIBaseURLConfig:
|
|||
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
|
||||
config = OpenAIConfig(api_key="test-key")
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == "https://api.openai.com/v1"
|
||||
|
||||
|
@ -27,6 +28,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://custom.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
||||
|
@ -38,6 +40,7 @@ class TestOpenAIBaseURLConfig:
|
|||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
assert adapter.get_base_url() == "https://env.openai.com/v1"
|
||||
|
||||
|
@ -47,6 +50,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://config.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Config should take precedence over environment variable
|
||||
assert adapter.get_base_url() == custom_url
|
||||
|
@ -57,6 +61,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
@ -76,6 +81,7 @@ class TestOpenAIBaseURLConfig:
|
|||
custom_url = "https://test.openai.com/v1"
|
||||
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
@ -117,6 +123,7 @@ class TestOpenAIBaseURLConfig:
|
|||
processed_config = replace_env_vars(config_data)
|
||||
config = OpenAIConfig.model_validate(processed_config)
|
||||
adapter = OpenAIInferenceAdapter(config)
|
||||
adapter.provider_data_api_key_field = None # Disable provider data for this test
|
||||
|
||||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
|
|
@ -4,18 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
def __init__(self):
|
||||
self.__provider_id__ = "test-provider"
|
||||
__provider_id__: str = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
@ -24,7 +26,7 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||
"""Test implementation with embedding model metadata"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
|
@ -32,14 +34,6 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
|||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
|
@ -366,3 +360,78 @@ class TestOpenAIMixinAllowedModels:
|
|||
assert await mixin.check_model_availability("final-mock-model-id")
|
||||
assert not await mixin.check_model_availability("some-mock-model-id")
|
||||
assert not await mixin.check_model_availability("another-mock-model-id")
|
||||
|
||||
|
||||
class ProviderDataValidator(BaseModel):
|
||||
"""Validator for provider data in tests"""
|
||||
|
||||
test_api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class OpenAIMixinWithProviderData(OpenAIMixinImpl):
|
||||
"""Test implementation that supports provider data API key field"""
|
||||
|
||||
provider_data_api_key_field: str = "test_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return "default-api-key"
|
||||
|
||||
def get_base_url(self):
|
||||
return "default-base-url"
|
||||
|
||||
|
||||
class TestOpenAIMixinProviderDataApiKey:
|
||||
"""Test cases for provider_data_api_key_field functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field(self):
|
||||
"""Mixin instance with provider_data_api_key_field set"""
|
||||
mixin_instance = OpenAIMixinWithProviderData()
|
||||
|
||||
# Mock provider_spec for provider data validation
|
||||
mock_provider_spec = MagicMock()
|
||||
mock_provider_spec.provider_type = "test-provider-with-data"
|
||||
mock_provider_spec.provider_data_validator = (
|
||||
"tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator"
|
||||
)
|
||||
mixin_instance.__provider_spec__ = mock_provider_spec
|
||||
|
||||
return mixin_instance
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field):
|
||||
mixin_with_provider_data_field.get_api_key = Mock(return_value=None)
|
||||
return mixin_with_provider_data_field
|
||||
|
||||
def test_no_provider_data(self, mixin_with_provider_data_field):
|
||||
"""Test that client uses config API key when no provider data is available"""
|
||||
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
|
||||
|
||||
def test_with_provider_data(self, mixin_with_provider_data_field):
|
||||
"""Test that provider data API key overrides config API key"""
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
|
||||
):
|
||||
assert mixin_with_provider_data_field.client.api_key == "provider-data-key"
|
||||
|
||||
def test_with_wrong_key(self, mixin_with_provider_data_field):
|
||||
"""Test fallback to config when provider data doesn't have the required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
|
||||
|
||||
def test_error_when_no_config_and_provider_data_has_wrong_key(
|
||||
self, mixin_with_provider_data_field_and_none_api_key
|
||||
):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
|
||||
"""Test that error message includes correct field name and header information"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue