diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index abba62f87..e9a41fcf3 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -254,6 +254,12 @@ class LiteLLMOpenAIMixin( api_key = getattr(provider_data, key_field) else: api_key = self.api_key_from_config + if not api_key: + raise ValueError( + "API key is not set. Please provide a valid API key in the " + "provider data header, e.g. x-llamastack-provider-data: " + f'{{"{key_field}": ""}}, or in the provider config.' + ) return api_key async def embeddings( diff --git a/tests/unit/providers/inference/test_litellm_openai_mixin.py b/tests/unit/providers/inference/test_litellm_openai_mixin.py new file mode 100644 index 000000000..bbc437edf --- /dev/null +++ b/tests/unit/providers/inference/test_litellm_openai_mixin.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from llama_stack.core.request_headers import request_provider_data_context +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin + + +# Test fixtures and helper classes +class TestConfig(BaseModel): + api_key: str | None = Field(default=None) + + +class TestProviderDataValidator(BaseModel): + test_api_key: str | None = Field(default=None) + + +class TestLiteLLMAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: TestConfig): + super().__init__( + model_entries=[], + litellm_provider_name="test", + api_key_from_config=config.api_key, + provider_data_api_key_field="test_api_key", + openai_compat_api_base=None, + ) + + +@pytest.fixture +def adapter_with_config_key(): + """Fixture to create adapter with API key in config""" + config = TestConfig(api_key="config-api-key") + adapter = TestLiteLLMAdapter(config) + adapter.__provider_spec__ = MagicMock() + adapter.__provider_spec__.provider_data_validator = ( + "tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator" + ) + return adapter + + +@pytest.fixture +def adapter_without_config_key(): + """Fixture to create adapter without API key in config""" + config = TestConfig(api_key=None) + adapter = TestLiteLLMAdapter(config) + adapter.__provider_spec__ = MagicMock() + adapter.__provider_spec__.provider_data_validator = ( + "tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator" + ) + return adapter + + +def test_api_key_from_config_when_no_provider_data(adapter_with_config_key): + """Test that adapter uses config API key when no provider data is available""" + api_key = adapter_with_config_key.get_api_key() + assert api_key == "config-api-key" + + +def test_provider_data_takes_priority_over_config(adapter_with_config_key): + """Test that provider data API key overrides config API key""" + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})} + ): + api_key = adapter_with_config_key.get_api_key() + assert api_key == "provider-data-key" + + +def test_fallback_to_config_when_provider_data_missing_key(adapter_with_config_key): + """Test fallback to config when provider data doesn't have the required key""" + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): + api_key = adapter_with_config_key.get_api_key() + assert api_key == "config-api-key" + + +def test_error_when_no_api_key_available(adapter_without_config_key): + """Test that ValueError is raised when neither config nor provider data have API key""" + with pytest.raises(ValueError, match="API key is not set"): + adapter_without_config_key.get_api_key() + + +def test_error_when_provider_data_has_wrong_key(adapter_without_config_key): + """Test that ValueError is raised when provider data exists but doesn't have required key""" + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): + with pytest.raises(ValueError, match="API key is not set"): + adapter_without_config_key.get_api_key() + + +def test_provider_data_works_when_config_is_none(adapter_without_config_key): + """Test that provider data works even when config has no API key""" + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-only-key"})} + ): + api_key = adapter_without_config_key.get_api_key() + assert api_key == "provider-only-key" + + +def test_error_message_includes_correct_field_names(adapter_without_config_key): + """Test that error message includes correct field name and header information""" + try: + adapter_without_config_key.get_api_key() + raise AssertionError("Should have raised ValueError") + except ValueError as e: + assert "test_api_key" in str(e) # Should mention the correct field name + assert "x-llamastack-provider-data" in str(e) # Should mention header name