Mirror of https://github.com/meta-llama/llama-stack.git
revert: do not use MySecretStr
We don't need MySecretStr if we can set the secret to an empty string.

Signed-off-by: Sébastien Han <seb@redhat.com>
parent bc64635835
commit 2a34226727
86 changed files with 208 additions and 263 deletions
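
Review note: the revert leans on two properties that pydantic's built-in SecretStr already has, which is why a custom MySecretStr type adds nothing. An empty secret is falsy, so SecretStr("") can stand in for "no key configured", and the raw value only escapes through an explicit get_secret_value() call. A minimal sketch of both properties (illustrative code, not part of this commit; ExampleConfig is a made-up model):

    from pydantic import BaseModel, SecretStr

    class ExampleConfig(BaseModel):
        api_key: SecretStr = SecretStr("")  # empty secret doubles as "unset"

    cfg = ExampleConfig()
    assert not cfg.api_key                             # empty secret is falsy
    cfg = ExampleConfig(api_key=SecretStr("s3kr3t"))
    assert cfg.api_key                                 # non-empty secret is truthy
    assert str(cfg.api_key) == "**********"            # repr/str stay masked
    assert cfg.api_key.get_secret_value() == "s3kr3t"  # unmasking is explicit
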
@@ -7,6 +7,7 @@
 import boto3
 import pytest
 from moto import mock_aws
+from pydantic import SecretStr
 
 from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -43,6 +44,7 @@ def s3_config(tmp_path):
         region="not-a-region",
         auto_create_bucket=True,
         metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
+        aws_secret_access_key=SecretStr("fake"),
     )

@@ -17,7 +17,7 @@ class TestBedrockBaseConfig:
 
         # Basic creds should be None
         assert config.aws_access_key_id is None
-        assert config.aws_secret_access_key is None
+        assert not config.aws_secret_access_key
         assert config.region_name is None
 
         # Timeouts get defaults
@@ -39,7 +39,7 @@ class TestBedrockBaseConfig:
         config = BedrockBaseConfig()
 
         assert config.aws_access_key_id == "AKIATEST123"
-        assert config.aws_secret_access_key == "secret123"
+        assert config.aws_secret_access_key.get_secret_value() == "secret123"
         assert config.region_name == "us-west-2"
         assert config.total_max_attempts == 5
         assert config.retry_mode == "adaptive"

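Two details in the Bedrock hunks are easy to miss. SecretStr never compares equal to a plain str, so the second assertion has to unwrap with get_secret_value(); and the first assertion switches from `is None` to a falsiness check, presumably because the unset default can now be an empty SecretStr rather than None, and `not` covers both cases:

    from pydantic import SecretStr

    assert SecretStr("secret123") != "secret123"            # no implicit str comparison
    assert SecretStr("secret123") == SecretStr("secret123")
    assert not SecretStr("")                                # falsy, like None
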
@@ -7,6 +7,8 @@
 import json
 from unittest.mock import MagicMock
 
+from pydantic import SecretStr
+
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
@@ -21,7 +23,7 @@ from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
 def test_groq_provider_openai_client_caching():
     """Ensure the Groq provider does not cache api keys across client requests"""
 
-    config = GroqConfig()
+    config = GroqConfig(api_key=SecretStr(""))
     inference_adapter = GroqInferenceAdapter(config)
 
     inference_adapter.__provider_spec__ = MagicMock()
@@ -33,13 +35,13 @@ def test_groq_provider_openai_client_caching():
         with request_provider_data_context(
             {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
         ):
-            assert inference_adapter.client.api_key.get_secret_value() == api_key
+            assert inference_adapter.client.api_key == api_key
 
 
 def test_openai_provider_openai_client_caching():
     """Ensure the OpenAI provider does not cache api keys across client requests"""
 
-    config = OpenAIConfig()
+    config = OpenAIConfig(api_key=SecretStr(""))
     inference_adapter = OpenAIInferenceAdapter(config)
 
     inference_adapter.__provider_spec__ = MagicMock()
@@ -52,13 +54,13 @@ def test_openai_provider_openai_client_caching():
             {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
         ):
             openai_client = inference_adapter.client
-            assert openai_client.api_key.get_secret_value() == api_key
+            assert openai_client.api_key == api_key
 
 
 def test_together_provider_openai_client_caching():
     """Ensure the Together provider does not cache api keys across client requests"""
 
-    config = TogetherImplConfig()
+    config = TogetherImplConfig(api_key=SecretStr(""))
     inference_adapter = TogetherInferenceAdapter(config)
 
     inference_adapter.__provider_spec__ = MagicMock()
@@ -76,7 +78,7 @@ def test_together_provider_openai_client_caching():
 
 def test_llama_compat_provider_openai_client_caching():
     """Ensure the LlamaCompat provider does not cache api keys across client requests"""
-    config = LlamaCompatConfig()
+    config = LlamaCompatConfig(api_key=SecretStr(""))
     inference_adapter = LlamaCompatInferenceAdapter(config)
 
     inference_adapter.__provider_spec__ = MagicMock()
@@ -86,4 +88,4 @@ def test_llama_compat_provider_openai_client_caching():
 
     for api_key in ["test1", "test2"]:
         with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
-            assert inference_adapter.client.api_key.get_secret_value() == api_key
+            assert inference_adapter.client.api_key == api_key

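These caching tests pin down a precedence rule: a key arriving in the x-llamastack-provider-data header must win over the config value, even when the config holds only an empty SecretStr. A hypothetical sketch of that resolution order (the helper name and shape are illustrative, not the adapters' actual code):

    from pydantic import SecretStr

    def resolve_api_key(config_key: SecretStr | None, request_key: str | None) -> str:
        if request_key:                  # per-request key always wins
            return request_key
        if config_key:                   # SecretStr("") is falsy, i.e. treated as unset
            return config_key.get_secret_value()
        raise ValueError("API key must come from config or request headers")

    assert resolve_api_key(SecretStr(""), "test1") == "test1"
    assert resolve_api_key(SecretStr("cfg-key"), None) == "cfg-key"
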
@@ -8,20 +8,19 @@ import json
 from unittest.mock import MagicMock
 
 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.core.request_headers import request_provider_data_context
-from llama_stack.core.secret_types import MySecretStr
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 
 # Test fixtures and helper classes
 class TestConfig(BaseModel):
-    api_key: MySecretStr | None = Field(default=None)
+    api_key: SecretStr | None = Field(default=None)
 
 
 class TestProviderDataValidator(BaseModel):
-    test_api_key: MySecretStr | None = Field(default=None)
+    test_api_key: SecretStr | None = Field(default=None)
 
 
 class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
@@ -37,7 +36,7 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
 @pytest.fixture
 def adapter_with_config_key():
     """Fixture to create adapter with API key in config"""
-    config = TestConfig(api_key=MySecretStr("config-api-key"))
+    config = TestConfig(api_key=SecretStr("config-api-key"))
     adapter = TestLiteLLMAdapter(config)
     adapter.__provider_spec__ = MagicMock()
     adapter.__provider_spec__.provider_data_validator = (

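The field-type swap here (MySecretStr | None becomes SecretStr | None) keeps the masking behaviour because it comes from pydantic itself: plain strings are coerced into SecretStr on validation, and serialized output stays masked. A quick illustrative check (DemoValidator is a made-up stand-in for the validators above):

    from pydantic import BaseModel, Field, SecretStr

    class DemoValidator(BaseModel):
        test_api_key: SecretStr | None = Field(default=None)

    v = DemoValidator(test_api_key="from-header")    # str coerces to SecretStr
    assert isinstance(v.test_api_key, SecretStr)
    assert "from-header" not in v.model_dump_json()  # JSON output stays masked
    assert v.test_api_key.get_secret_value() == "from-header"
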
@@ -7,14 +7,9 @@
 import os
 from unittest.mock import MagicMock, patch
 
-from llama_stack.core.secret_types import MySecretStr
-
-
-# Wrapper for backward compatibility in tests
-def replace_env_vars_compat(config, path=""):
-    return replace_env_vars(config, path, None, None)
+from pydantic import SecretStr
 
 from llama_stack.core.stack import replace_env_vars
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
 
@@ -42,7 +37,7 @@ class TestOpenAIBaseURLConfig:
         """Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
         # Use sample_run_config which has proper environment variable syntax
         config_data = OpenAIConfig.sample_run_config(api_key="test-key")
-        processed_config = replace_env_vars_compat(config_data)
+        processed_config = replace_env_vars(config_data)
         config = OpenAIConfig.model_validate(processed_config)
         adapter = OpenAIInferenceAdapter(config)
 
@@ -66,14 +61,14 @@ class TestOpenAIBaseURLConfig:
         adapter = OpenAIInferenceAdapter(config)
 
         # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
-        adapter.get_api_key = MagicMock(return_value=MySecretStr("test-key"))
+        adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))
 
         # Access the client property to trigger AsyncOpenAI initialization
         _ = adapter.client
 
         # Verify AsyncOpenAI was called with the correct base_url
         mock_openai_class.assert_called_once_with(
-            api_key=MySecretStr("test-key"),
+            api_key=SecretStr("test-key").get_secret_value(),
             base_url=custom_url,
         )
 
@@ -85,7 +80,7 @@ class TestOpenAIBaseURLConfig:
         adapter = OpenAIInferenceAdapter(config)
 
         # Mock the get_api_key method
-        adapter.get_api_key = MagicMock(return_value=MySecretStr("test-key"))
+        adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))
 
         # Mock a model object that will be returned by models.list()
         mock_model = MagicMock()
@@ -108,7 +103,7 @@ class TestOpenAIBaseURLConfig:
 
         # Verify the client was created with the custom URL
         mock_openai_class.assert_called_with(
-            api_key=MySecretStr("test-key"),
+            api_key=SecretStr("test-key").get_secret_value(),
             base_url=custom_url,
         )
 
@@ -121,12 +116,12 @@ class TestOpenAIBaseURLConfig:
         """Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
         # Use sample_run_config which has proper environment variable syntax
         config_data = OpenAIConfig.sample_run_config(api_key="test-key")
-        processed_config = replace_env_vars_compat(config_data)
+        processed_config = replace_env_vars(config_data)
         config = OpenAIConfig.model_validate(processed_config)
         adapter = OpenAIInferenceAdapter(config)
 
         # Mock the get_api_key method
-        adapter.get_api_key = MagicMock(return_value=MySecretStr("test-key"))
+        adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))
 
         # Mock a model object that will be returned by models.list()
         mock_model = MagicMock()
@@ -149,6 +144,6 @@ class TestOpenAIBaseURLConfig:
 
         # Verify the client was created with the environment variable URL
         mock_openai_class.assert_called_with(
-            api_key=MySecretStr("test-key"),
+            api_key=SecretStr("test-key").get_secret_value(),
             base_url="https://proxy.openai.com/v1",
         )

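The changed assert_called_with lines are the crux of this file: AsyncOpenAI takes a plain string, so the adapter is expected to unwrap the secret exactly once, at the client boundary. A sketch of that boundary (AsyncOpenAI is the real openai-python class; build_client is an illustrative helper, not the adapter's code):

    from openai import AsyncOpenAI
    from pydantic import SecretStr

    def build_client(api_key: SecretStr, base_url: str) -> AsyncOpenAI:
        # the secret stays wrapped until the last moment, then crosses as plain str
        return AsyncOpenAI(api_key=api_key.get_secret_value(), base_url=base_url)
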
@@ -26,6 +26,7 @@ from openai.types.chat.chat_completion_chunk import (
     ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
 )
 from openai.types.model import Model as OpenAIModel
+from pydantic import SecretStr
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -688,31 +689,35 @@ async def test_should_refresh_models():
     """
 
     # Test case 1: refresh_models is True, api_token is None
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=SecretStr(""), refresh_models=True)
     adapter1 = VLLMInferenceAdapter(config1)
     result1 = await adapter1.should_refresh_models()
     assert result1 is True, "should_refresh_models should return True when refresh_models is True"
 
     # Test case 2: refresh_models is True, api_token is empty string
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=SecretStr(""), refresh_models=True)
     adapter2 = VLLMInferenceAdapter(config2)
     result2 = await adapter2.should_refresh_models()
     assert result2 is True, "should_refresh_models should return True when refresh_models is True"
 
     # Test case 3: refresh_models is True, api_token is "fake" (default)
-    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
+    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=SecretStr("fake"), refresh_models=True)
     adapter3 = VLLMInferenceAdapter(config3)
     result3 = await adapter3.should_refresh_models()
     assert result3 is True, "should_refresh_models should return True when refresh_models is True"
 
     # Test case 4: refresh_models is True, api_token is real token
-    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
+    config4 = VLLMInferenceAdapterConfig(
+        url="http://test.localhost", api_token=SecretStr("real-token-123"), refresh_models=True
+    )
     adapter4 = VLLMInferenceAdapter(config4)
     result4 = await adapter4.should_refresh_models()
     assert result4 is True, "should_refresh_models should return True when refresh_models is True"
 
     # Test case 5: refresh_models is False, api_token is real token
-    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
+    config5 = VLLMInferenceAdapterConfig(
+        url="http://test.localhost", api_token=SecretStr("real-token-456"), refresh_models=False
+    )
    adapter5 = VLLMInferenceAdapter(config5)
     result5 = await adapter5.should_refresh_models()
     assert result5 is False, "should_refresh_models should return False when refresh_models is False"
@@ -735,7 +740,7 @@ async def test_provider_data_var_context_propagation(vllm_inference_adapter):
 
         # Mock provider data to return test data
         mock_provider_data = MagicMock()
-        mock_provider_data.vllm_api_token = "test-token-123"
+        mock_provider_data.vllm_api_token = SecretStr("test-token-123")
         mock_provider_data.vllm_url = "http://test-server:8000/v1"
         mock_get_provider_data.return_value = mock_provider_data
 
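All five cases assert the same outcome: the refresh decision tracks refresh_models alone, and the token value (empty, "fake", or real) is irrelevant. If the adapter follows what the tests encode, the logic reduces to something like this (an assumption sketched with stand-in classes, not the adapter's verbatim code):

    from pydantic import BaseModel, SecretStr

    class StubConfig(BaseModel):            # stand-in for VLLMInferenceAdapterConfig
        url: str
        api_token: SecretStr = SecretStr("")
        refresh_models: bool = False

    class StubAdapter:                      # stand-in for VLLMInferenceAdapter
        def __init__(self, config: StubConfig):
            self.config = config

        async def should_refresh_models(self) -> bool:
            # per the five cases above, the token plays no role in the decision
            return self.config.refresh_models
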
@@ -9,6 +9,7 @@ import warnings
 from unittest.mock import patch
 
 import pytest
+from pydantic import SecretStr
 
 from llama_stack.apis.post_training.post_training import (
     DataConfig,
@@ -32,7 +33,7 @@ class TestNvidiaParameters:
         """Setup and teardown for each test method."""
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
 
-        config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
+        config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=SecretStr(""))
         self.adapter = NvidiaPostTrainingAdapter(config)
 
         self.make_request_patcher = patch(

@@ -9,6 +9,7 @@ import warnings
 from unittest.mock import patch
 
 import pytest
+from pydantic import SecretStr
 
 from llama_stack.apis.post_training.post_training import (
     DataConfig,
@@ -34,7 +35,7 @@ def nvidia_post_training_adapter():
     """Fixture to create and configure the NVIDIA post training adapter."""
     os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer
 
-    config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
+    config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=SecretStr(""))
     adapter = NvidiaPostTrainingAdapter(config)
 
     with patch.object(adapter, "_make_request") as mock_make_request:

@@ -8,10 +8,7 @@ import os
 
 import pytest
 
-
-# Wrapper for backward compatibility in tests
-def replace_env_vars_compat(config, path=""):
-    return replace_env_vars(config, path, None, None)
+from llama_stack.core.stack import replace_env_vars
 
 
 @pytest.fixture
@@ -35,54 +32,52 @@ def setup_env_vars():
 
 
 def test_simple_replacement(setup_env_vars):
-    assert replace_env_vars_compat("${env.TEST_VAR}") == "test_value"
+    assert replace_env_vars("${env.TEST_VAR}") == "test_value"
 
 
 def test_default_value_when_not_set(setup_env_vars):
-    assert replace_env_vars_compat("${env.NOT_SET:=default}") == "default"
+    assert replace_env_vars("${env.NOT_SET:=default}") == "default"
 
 
 def test_default_value_when_set(setup_env_vars):
-    assert replace_env_vars_compat("${env.TEST_VAR:=default}") == "test_value"
+    assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"
 
 
 def test_default_value_when_empty(setup_env_vars):
-    assert replace_env_vars_compat("${env.EMPTY_VAR:=default}") == "default"
+    assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"
 
 
 def test_none_value_when_empty(setup_env_vars):
-    assert replace_env_vars_compat("${env.EMPTY_VAR:=}") is None
+    assert replace_env_vars("${env.EMPTY_VAR:=}") is None
 
 
 def test_value_when_set(setup_env_vars):
-    assert replace_env_vars_compat("${env.TEST_VAR:=}") == "test_value"
+    assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"
 
 
 def test_empty_var_no_default(setup_env_vars):
-    assert replace_env_vars_compat("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
+    assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
 
 
 def test_conditional_value_when_set(setup_env_vars):
-    assert replace_env_vars_compat("${env.TEST_VAR:+conditional}") == "conditional"
+    assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"
 
 
 def test_conditional_value_when_not_set(setup_env_vars):
-    assert replace_env_vars_compat("${env.NOT_SET:+conditional}") is None
+    assert replace_env_vars("${env.NOT_SET:+conditional}") is None
 
 
 def test_conditional_value_when_empty(setup_env_vars):
-    assert replace_env_vars_compat("${env.EMPTY_VAR:+conditional}") is None
+    assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None
 
 
 def test_conditional_value_with_zero(setup_env_vars):
-    assert replace_env_vars_compat("${env.ZERO_VAR:+conditional}") == "conditional"
+    assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"
 
 
 def test_mixed_syntax(setup_env_vars):
-    assert replace_env_vars_compat("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
-    assert (
-        replace_env_vars_compat("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
-    )
+    assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
+    assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
 
 
 def test_nested_structures(setup_env_vars):
@@ -92,11 +87,11 @@ def test_nested_structures(setup_env_vars):
         "key3": {"nested": "${env.NOT_SET:+conditional}"},
     }
     expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
-    assert replace_env_vars_compat(data) == expected
+    assert replace_env_vars(data) == expected
 
 
 def test_explicit_strings_preserved(setup_env_vars):
     # Explicit strings that look like numbers/booleans should remain strings
     data = {"port": "8080", "enabled": "true", "count": "123", "ratio": "3.14"}
     expected = {"port": "8080", "enabled": "true", "count": "123", "ratio": "3.14"}
-    assert replace_env_vars_compat(data) == expected
+    assert replace_env_vars(data) == expected

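Taken together these tests specify a small substitution grammar: ${env.VAR} inlines a variable, :=default supplies a fallback when the variable is unset or empty (an empty default collapses to None), and :+value yields value only when the variable is set and non-empty; embedded references render an empty result as "". A minimal re-implementation satisfying the expectations above (a sketch only, not llama_stack's replace_env_vars, which also handles type inference and extra arguments):

    import os
    import re

    _ENV_RE = re.compile(r"\$\{env\.(\w+)(?::(=|\+)([^}]*))?\}")

    def _resolve(match: re.Match) -> str | None:
        name, op, arg = match.group(1), match.group(2), match.group(3)
        raw = os.environ.get(name)
        if op == "=":              # default: used when unset or empty
            return raw if raw else arg
        if op == "+":              # conditional: arg only when set and non-empty
            return arg if raw else None
        return raw                 # bare ${env.VAR}

    def substitute(value):
        """Recursively apply the substitution rules the tests above encode."""
        if isinstance(value, dict):
            return {k: substitute(v) for k, v in value.items()}
        if isinstance(value, list):
            return [substitute(v) for v in value]
        if not isinstance(value, str):
            return value
        whole = _ENV_RE.fullmatch(value)
        if whole:                  # a whole-string reference may collapse to None
            out = _resolve(whole)
            return out if out else None
        return _ENV_RE.sub(lambda m: _resolve(m) or "", value)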