feat: use SecretStr for inference provider auth credentials (#3724)

# What does this PR do?

use SecretStr for OpenAIMixin providers

- RemoteInferenceProviderConfig now has auth_credential: SecretStr
- the default alias is api_key (most common name)
- some providers override to use api_token (RunPod, vLLM, Databricks)
- some providers exclude it (Ollama, TGI, Vertex AI)

addresses #3517 

## Test Plan

ci w/ new tests
This commit is contained in:
Matthew Farrellee 2025-10-10 10:32:50 -04:00 committed by GitHub
parent 6d8f61206e
commit 0066d986c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
57 changed files with 158 additions and 149 deletions

View file

@@ -720,7 +720,7 @@ class TestOpenAIMixinProviderDataApiKey:
):
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
with pytest.raises(ValueError, match="API key is not set"):
with pytest.raises(ValueError, match="API key not provided"):
_ = mixin_with_provider_data_field_and_none_api_key.client
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):

View file

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
from llama_stack.providers.remote.inference.azure import AzureConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.gemini import GeminiConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.openai import OpenAIConfig
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
class TestRemoteInferenceProviderConfig:
    """Checks that every remote inference provider config exposes its
    credential under the provider-specific public alias (``api_key`` /
    ``api_token``) while storing it internally as ``auth_credential``."""

    @pytest.mark.parametrize(
        "config_cls,alias_name,env_name,extra_config",
        [
            (AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
            (AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
            (BedrockConfig, None, None, {}),
            (CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
            (DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
            (FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
            (GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
            (GroqConfig, "api_key", "GROQ_API_KEY", {}),
            (LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
            (NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
            (OllamaImplConfig, None, None, {}),
            (OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
            (RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
            (SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
            (TGIImplConfig, None, None, {"url": "FAKE"}),
            (TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
            (VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
            (VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
            (WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
        ],
    )
    def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
        """Test that the config class correctly maps the alias to auth_credential."""
        # Providers without a credential alias (e.g. Ollama, TGI) have
        # nothing to verify here.
        if alias_name is None:
            pytest.skip("No alias name provided for this config class.")

        # Use the class name as a unique, recognizable secret.
        secret_value = config_cls.__name__

        # Constructing via the public alias must populate auth_credential.
        cfg = config_cls(**{alias_name: secret_value, **extra_config})
        assert cfg.auth_credential is not None
        assert cfg.auth_credential.get_secret_value() == secret_value

        # The JSON schema should advertise only the public alias, never
        # the internal field name.
        properties = config_cls.model_json_schema()["properties"]
        assert alias_name in properties
        assert "auth_credential" not in properties

        # Round-trip: sample_run_config -> env var expansion -> config
        # must carry the secret through to auth_credential.
        if env_name:
            monkeypatch.setenv(env_name, secret_value)
            expanded = replace_env_vars(config_cls.sample_run_config())
            roundtripped = config_cls(**{**expanded, **extra_config})
            assert roundtripped.auth_credential is not None
            assert roundtripped.auth_credential.get_secret_value() == secret_value