commit 79ced0c85b
Author: Sébastien Han (committed by GitHub)
Date:   2025-10-01 15:47:54 +02:00
94 changed files with 341 additions and 209 deletions
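
Every diff below applies the same migration: provider credentials move from plain str fields to Pydantic's SecretStr, which masks the value in repr() and str() output and only yields the raw string through an explicit get_secret_value() call. A minimal standalone sketch of that behavior (plain Pydantic, not llama_stack code):

    from pydantic import BaseModel, SecretStr

    class Config(BaseModel):
        api_key: SecretStr | None = None

    config = Config(api_key=SecretStr("secret123"))
    print(config)  # api_key=SecretStr('**********'), nothing leaks into logs
    assert str(config.api_key) == "**********"
    assert config.api_key.get_secret_value() == "secret123"  # explicit unwrap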


@@ -7,6 +7,7 @@
 import boto3
 import pytest
 from moto import mock_aws
+from pydantic import SecretStr

 from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -43,6 +44,7 @@ def s3_config(tmp_path):
         region="not-a-region",
         auto_create_bucket=True,
         metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
+        aws_secret_access_key=SecretStr("fake"),
     )
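
Worth noting for fixtures like this one: moto never validates the credential, so the fixture just wraps a placeholder, and placeholders stay safe to assert on because SecretStr compares by its underlying value. A quick standalone check (plain Pydantic):

    from pydantic import SecretStr

    assert SecretStr("fake") == SecretStr("fake")   # equality compares the wrapped value
    assert SecretStr("fake") != SecretStr("other")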


@@ -17,7 +17,7 @@ class TestBedrockBaseConfig:
         # Basic creds should be None
         assert config.aws_access_key_id is None
-        assert config.aws_secret_access_key is None
+        assert not config.aws_secret_access_key
         assert config.region_name is None

         # Timeouts get defaults
@@ -39,7 +39,7 @@ class TestBedrockBaseConfig:
         config = BedrockBaseConfig()
         assert config.aws_access_key_id == "AKIATEST123"
-        assert config.aws_secret_access_key == "secret123"
+        assert config.aws_secret_access_key.get_secret_value() == "secret123"
         assert config.region_name == "us-west-2"

         assert config.total_max_attempts == 5
         assert config.retry_mode == "adaptive"
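
The loosened first assertion is deliberate: with the field now secret-typed, "no credential" can be either None or an empty secret, and both are falsy, since SecretStr defines __len__ over the wrapped string (the exact field annotation isn't shown in this hunk, so that part is inferred). A standalone check of that assumption:

    from pydantic import SecretStr

    assert not SecretStr("")         # empty secret is falsy (len == 0)
    assert SecretStr("secret123")    # non-empty secret is truthy
    # None is falsy too, so `not config.aws_secret_access_key` covers both states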


@@ -7,6 +7,8 @@
 import json
 from unittest.mock import MagicMock

+from pydantic import SecretStr
+
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
@@ -21,7 +23,7 @@ from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter

 def test_groq_provider_openai_client_caching():
     """Ensure the Groq provider does not cache api keys across client requests"""
-    config = GroqConfig()
+    config = GroqConfig(api_key=SecretStr(""))
     inference_adapter = GroqInferenceAdapter(config)

     inference_adapter.__provider_spec__ = MagicMock()
@@ -39,7 +41,7 @@ def test_groq_provider_openai_client_caching():

 def test_openai_provider_openai_client_caching():
     """Ensure the OpenAI provider does not cache api keys across client requests"""
-    config = OpenAIConfig()
+    config = OpenAIConfig(api_key=SecretStr(""))
     inference_adapter = OpenAIInferenceAdapter(config)

     inference_adapter.__provider_spec__ = MagicMock()
@@ -58,7 +60,7 @@ def test_openai_provider_openai_client_caching():

 def test_together_provider_openai_client_caching():
     """Ensure the Together provider does not cache api keys across client requests"""
-    config = TogetherImplConfig()
+    config = TogetherImplConfig(api_key=SecretStr(""))
     inference_adapter = TogetherInferenceAdapter(config)

     inference_adapter.__provider_spec__ = MagicMock()
@@ -76,7 +78,7 @@ def test_together_provider_openai_client_caching():

 def test_llama_compat_provider_openai_client_caching():
     """Ensure the LlamaCompat provider does not cache api keys across client requests"""
-    config = LlamaCompatConfig()
+    config = LlamaCompatConfig(api_key=SecretStr(""))
     inference_adapter = LlamaCompatInferenceAdapter(config)

     inference_adapter.__provider_spec__ = MagicMock()
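
All four caching tests construct their configs with an explicit empty secret rather than a bare default. Outside of tests, the natural counterpart is wrapping the real key at the process boundary. A sketch under assumptions: GroqConfig accepts api_key as a SecretStr (as the tests above imply), and GROQ_API_KEY is a hypothetical env var name:

    import os

    from pydantic import SecretStr

    from llama_stack.providers.remote.inference.groq.config import GroqConfig

    # hypothetical env var name; wrap the secret as early as possible
    config = GroqConfig(api_key=SecretStr(os.environ.get("GROQ_API_KEY", "")))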


@@ -8,7 +8,7 @@ import json
 from unittest.mock import MagicMock

 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -16,11 +16,11 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

 # Test fixtures and helper classes
 class TestConfig(BaseModel):
-    api_key: str | None = Field(default=None)
+    api_key: SecretStr | None = Field(default=None)


 class TestProviderDataValidator(BaseModel):
-    test_api_key: str | None = Field(default=None)
+    test_api_key: SecretStr | None = Field(default=None)


 class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
@@ -36,7 +36,7 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
 @pytest.fixture
 def adapter_with_config_key():
     """Fixture to create adapter with API key in config"""
-    config = TestConfig(api_key="config-api-key")
+    config = TestConfig(api_key=SecretStr("config-api-key"))
     adapter = TestLiteLLMAdapter(config)
     adapter.__provider_spec__ = MagicMock()
     adapter.__provider_spec__.provider_data_validator = (
@@ -59,7 +59,7 @@ def adapter_without_config_key():
 def test_api_key_from_config_when_no_provider_data(adapter_with_config_key):
     """Test that adapter uses config API key when no provider data is available"""
-    api_key = adapter_with_config_key.get_api_key()
+    api_key = adapter_with_config_key.get_api_key().get_secret_value()
     assert api_key == "config-api-key"
@@ -68,28 +68,28 @@ def test_provider_data_takes_priority_over_config(adapter_with_config_key):
     with request_provider_data_context(
         {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
     ):
-        api_key = adapter_with_config_key.get_api_key()
+        api_key = adapter_with_config_key.get_api_key().get_secret_value()
         assert api_key == "provider-data-key"


 def test_fallback_to_config_when_provider_data_missing_key(adapter_with_config_key):
     """Test fallback to config when provider data doesn't have the required key"""
     with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
-        api_key = adapter_with_config_key.get_api_key()
+        api_key = adapter_with_config_key.get_api_key().get_secret_value()
         assert api_key == "config-api-key"


 def test_error_when_no_api_key_available(adapter_without_config_key):
     """Test that ValueError is raised when neither config nor provider data have API key"""
     with pytest.raises(ValueError, match="API key is not set"):
-        adapter_without_config_key.get_api_key()
+        adapter_without_config_key.get_api_key().get_secret_value()


 def test_error_when_provider_data_has_wrong_key(adapter_without_config_key):
     """Test that ValueError is raised when provider data exists but doesn't have required key"""
     with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
         with pytest.raises(ValueError, match="API key is not set"):
-            adapter_without_config_key.get_api_key()
+            adapter_without_config_key.get_api_key().get_secret_value()


 def test_provider_data_works_when_config_is_none(adapter_without_config_key):
@@ -97,14 +97,14 @@ def test_provider_data_works_when_config_is_none(adapter_without_config_key):
     with request_provider_data_context(
         {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-only-key"})}
     ):
-        api_key = adapter_without_config_key.get_api_key()
+        api_key = adapter_without_config_key.get_api_key().get_secret_value()
         assert api_key == "provider-only-key"


 def test_error_message_includes_correct_field_names(adapter_without_config_key):
     """Test that error message includes correct field name and header information"""
     try:
-        adapter_without_config_key.get_api_key()
+        adapter_without_config_key.get_api_key().get_secret_value()
         raise AssertionError("Should have raised ValueError")
     except ValueError as e:
         assert "test_api_key" in str(e)  # Should mention the correct field name


@@ -7,6 +7,8 @@
 import os
 from unittest.mock import MagicMock, patch

+from pydantic import SecretStr
+
 from llama_stack.core.stack import replace_env_vars
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
@@ -59,14 +61,14 @@ class TestOpenAIBaseURLConfig:
         adapter = OpenAIInferenceAdapter(config)

         # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
-        adapter.get_api_key = MagicMock(return_value="test-key")
+        adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))

         # Access the client property to trigger AsyncOpenAI initialization
         _ = adapter.client

         # Verify AsyncOpenAI was called with the correct base_url
         mock_openai_class.assert_called_once_with(
-            api_key="test-key",
+            api_key=SecretStr("test-key").get_secret_value(),
             base_url=custom_url,
         )
@@ -78,7 +80,7 @@ class TestOpenAIBaseURLConfig:
         adapter = OpenAIInferenceAdapter(config)

         # Mock the get_api_key method
-        adapter.get_api_key = MagicMock(return_value="test-key")
+        adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))

         # Mock a model object that will be returned by models.list()
         mock_model = MagicMock()
@@ -101,7 +103,7 @@ class TestOpenAIBaseURLConfig:
         # Verify the client was created with the custom URL
         mock_openai_class.assert_called_with(
-            api_key="test-key",
+            api_key=SecretStr("test-key").get_secret_value(),
             base_url=custom_url,
         )
@@ -119,7 +121,7 @@ class TestOpenAIBaseURLConfig:
         adapter = OpenAIInferenceAdapter(config)

         # Mock the get_api_key method
-        adapter.get_api_key = MagicMock(return_value="test-key")
+        adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))

         # Mock a model object that will be returned by models.list()
         mock_model = MagicMock()
@@ -142,6 +144,6 @@ class TestOpenAIBaseURLConfig:
         # Verify the client was created with the environment variable URL
         mock_openai_class.assert_called_with(
-            api_key="test-key",
+            api_key=SecretStr("test-key").get_secret_value(),
             base_url="https://proxy.openai.com/v1",
         )
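
Note the shape of these assertions: get_api_key is mocked to return a SecretStr, yet the AsyncOpenAI constructor is expected to receive the plain string, so the adapter evidently unwraps the secret exactly at the client boundary. SecretStr("test-key").get_secret_value() in the expectation is simply "test-key"; condensed:

    from unittest.mock import MagicMock

    from pydantic import SecretStr

    get_api_key = MagicMock(return_value=SecretStr("test-key"))
    # unwrap at the boundary: the HTTP client only ever sees the raw string
    assert get_api_key().get_secret_value() == "test-key"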


@@ -26,6 +26,7 @@ from openai.types.chat.chat_completion_chunk import (
     ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
 )
 from openai.types.model import Model as OpenAIModel
+from pydantic import SecretStr

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -688,31 +689,35 @@ async def test_should_refresh_models():
     """
     # Test case 1: refresh_models is True, api_token is None
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=SecretStr(""), refresh_models=True)
     adapter1 = VLLMInferenceAdapter(config1)
     result1 = await adapter1.should_refresh_models()
     assert result1 is True, "should_refresh_models should return True when refresh_models is True"

     # Test case 2: refresh_models is True, api_token is empty string
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=SecretStr(""), refresh_models=True)
     adapter2 = VLLMInferenceAdapter(config2)
     result2 = await adapter2.should_refresh_models()
     assert result2 is True, "should_refresh_models should return True when refresh_models is True"

     # Test case 3: refresh_models is True, api_token is "fake" (default)
-    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
+    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=SecretStr("fake"), refresh_models=True)
     adapter3 = VLLMInferenceAdapter(config3)
     result3 = await adapter3.should_refresh_models()
     assert result3 is True, "should_refresh_models should return True when refresh_models is True"

     # Test case 4: refresh_models is True, api_token is real token
-    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
+    config4 = VLLMInferenceAdapterConfig(
+        url="http://test.localhost", api_token=SecretStr("real-token-123"), refresh_models=True
+    )
     adapter4 = VLLMInferenceAdapter(config4)
     result4 = await adapter4.should_refresh_models()
     assert result4 is True, "should_refresh_models should return True when refresh_models is True"

     # Test case 5: refresh_models is False, api_token is real token
-    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
+    config5 = VLLMInferenceAdapterConfig(
+        url="http://test.localhost", api_token=SecretStr("real-token-456"), refresh_models=False
+    )
     adapter5 = VLLMInferenceAdapter(config5)
     result5 = await adapter5.should_refresh_models()
     assert result5 is False, "should_refresh_models should return False when refresh_models is False"
@@ -735,7 +740,7 @@ async def test_provider_data_var_context_propagation(vllm_inference_adapter):
         # Mock provider data to return test data
         mock_provider_data = MagicMock()
-        mock_provider_data.vllm_api_token = "test-token-123"
+        mock_provider_data.vllm_api_token = SecretStr("test-token-123")
         mock_provider_data.vllm_url = "http://test-server:8000/v1"
         mock_get_provider_data.return_value = mock_provider_data
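
A practical payoff of the switch is visible whenever a config gets logged or repr'd: the token is masked. A standalone sketch, with DemoConfig as a hypothetical stand-in for VLLMInferenceAdapterConfig:

    from pydantic import BaseModel, SecretStr

    class DemoConfig(BaseModel):  # stand-in, not the real adapter config
        url: str
        api_token: SecretStr

    cfg = DemoConfig(url="http://test.localhost", api_token=SecretStr("real-token-123"))
    print(cfg)  # url='http://test.localhost' api_token=SecretStr('**********')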


@@ -9,6 +9,7 @@ import warnings
 from unittest.mock import patch

 import pytest
+from pydantic import SecretStr

 from llama_stack.apis.post_training.post_training import (
     DataConfig,
@@ -32,7 +33,7 @@ class TestNvidiaParameters:
         """Setup and teardown for each test method."""
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"

-        config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
+        config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=SecretStr(""))
         self.adapter = NvidiaPostTrainingAdapter(config)

         self.make_request_patcher = patch(
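
This file and the next make the identical substitution: api_key=None becomes api_key=SecretStr(""). One plausible reason to prefer an explicit empty secret over None (an assumption, the commit message doesn't say) is that callers can unwrap unconditionally instead of None-checking first:

    from pydantic import SecretStr

    api_key = SecretStr("")              # "no key" sentinel that is still a SecretStr
    token = api_key.get_secret_value()   # safe to call, no None check needed
    assert token == "" and not api_key   # and it stays falsy for presence checks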


@@ -9,6 +9,7 @@ import warnings
 from unittest.mock import patch

 import pytest
+from pydantic import SecretStr

 from llama_stack.apis.post_training.post_training import (
     DataConfig,
@@ -34,7 +35,7 @@ def nvidia_post_training_adapter():
     """Fixture to create and configure the NVIDIA post training adapter."""
     os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer

-    config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
+    config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=SecretStr(""))
     adapter = NvidiaPostTrainingAdapter(config)

     with patch.object(adapter, "_make_request") as mock_make_request: