mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: add provider-data-api-key support to openaimixin (#3639)
# What does this PR do? the LiteLLMOpenAIMixin provides support for reading key from provider data (headers users send). this adds the same functionality to the OpenAIMixin. this is infrastructure for migrating providers. ## Test Plan ci w/ new tests
This commit is contained in:
parent
28bbbcf2c1
commit
4dbe0593f9
5 changed files with 116 additions and 17 deletions
|
@ -4,18 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
def __init__(self):
|
||||
self.__provider_id__ = "test-provider"
|
||||
__provider_id__: str = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
@ -24,7 +26,7 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||
"""Test implementation with embedding model metadata"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
|
@ -32,14 +34,6 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
|||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
|
@ -366,3 +360,78 @@ class TestOpenAIMixinAllowedModels:
|
|||
assert await mixin.check_model_availability("final-mock-model-id")
|
||||
assert not await mixin.check_model_availability("some-mock-model-id")
|
||||
assert not await mixin.check_model_availability("another-mock-model-id")
|
||||
|
||||
|
||||
class ProviderDataValidator(BaseModel):
|
||||
"""Validator for provider data in tests"""
|
||||
|
||||
test_api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class OpenAIMixinWithProviderData(OpenAIMixinImpl):
|
||||
"""Test implementation that supports provider data API key field"""
|
||||
|
||||
provider_data_api_key_field: str = "test_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return "default-api-key"
|
||||
|
||||
def get_base_url(self):
|
||||
return "default-base-url"
|
||||
|
||||
|
||||
class TestOpenAIMixinProviderDataApiKey:
|
||||
"""Test cases for provider_data_api_key_field functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field(self):
|
||||
"""Mixin instance with provider_data_api_key_field set"""
|
||||
mixin_instance = OpenAIMixinWithProviderData()
|
||||
|
||||
# Mock provider_spec for provider data validation
|
||||
mock_provider_spec = MagicMock()
|
||||
mock_provider_spec.provider_type = "test-provider-with-data"
|
||||
mock_provider_spec.provider_data_validator = (
|
||||
"tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator"
|
||||
)
|
||||
mixin_instance.__provider_spec__ = mock_provider_spec
|
||||
|
||||
return mixin_instance
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field):
|
||||
mixin_with_provider_data_field.get_api_key = Mock(return_value=None)
|
||||
return mixin_with_provider_data_field
|
||||
|
||||
def test_no_provider_data(self, mixin_with_provider_data_field):
|
||||
"""Test that client uses config API key when no provider data is available"""
|
||||
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
|
||||
|
||||
def test_with_provider_data(self, mixin_with_provider_data_field):
|
||||
"""Test that provider data API key overrides config API key"""
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
|
||||
):
|
||||
assert mixin_with_provider_data_field.client.api_key == "provider-data-key"
|
||||
|
||||
def test_with_wrong_key(self, mixin_with_provider_data_field):
|
||||
"""Test fallback to config when provider data doesn't have the required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
assert mixin_with_provider_data_field.client.api_key == "default-api-key"
|
||||
|
||||
def test_error_when_no_config_and_provider_data_has_wrong_key(
|
||||
self, mixin_with_provider_data_field_and_none_api_key
|
||||
):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
|
||||
"""Test that error message includes correct field name and header information"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue