feat: use SecretStr for inference provider auth credentials (#3724)

# What does this PR do?

Use SecretStr for the auth credentials of OpenAIMixin providers.

- RemoteInferenceProviderConfig now has an `auth_credential: SecretStr | None` field
- its default alias is `api_key` (the most common name)
- some providers override the alias to `api_token` (RunPod, vLLM, Databricks)
- some providers exclude the field from serialization (Ollama, TGI, Vertex AI, Bedrock)

addresses #3517 

## Test Plan

CI, with new tests.
This commit is contained in:
Matthew Farrellee 2025-10-10 10:32:50 -04:00 committed by GitHub
parent 6d8f61206e
commit 0066d986c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
57 changed files with 158 additions and 149 deletions

View file

@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://api.anthropic.com/v1"

View file

@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "azure_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
"""
Get the Azure API base URL.

View file

@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)

View file

@ -15,9 +15,6 @@ from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -18,8 +18,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: SecretStr = Field(
default=SecretStr(None), # type: ignore[arg-type]
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The Databricks API token",
)

View file

@ -27,9 +27,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Fireworks.ai API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "fireworks_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"

View file

@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -17,8 +17,5 @@ class GeminiInferenceAdapter(OpenAIMixin):
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"

View file

@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,
description="The Groq API key",
)
url: str = Field(
default="https://api.groq.com",
description="The URL for the Groq AI server",

View file

@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "groq_api_key"
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"

View file

@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="The Llama API key",
)
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",

View file

@ -21,9 +21,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
Llama API Inference Adapter for Llama Stack.
"""
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM",
)
api_key: SecretStr | None = Field(
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",

View file

@ -49,7 +49,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(self.config):
if not self.config.api_key:
if not self.config.auth_credential:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
@ -60,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API key
"""
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
if not _is_nvidia_hosted(self.config):
return "NO KEY REQUIRED"
return None
def get_base_url(self) -> str:
"""

View file

@ -6,12 +6,16 @@
from typing import Any
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
@classmethod

View file

@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
return self._clients[loop]
def get_api_key(self):
return "NO_KEY"
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"

View file

@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",

View file

@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin):
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
Get the OpenAI API base URL.

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
default=None,
description="The URL for the Runpod model serving endpoint",
)
api_token: str | None = Field(
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)

View file

@ -24,10 +24,6 @@ class RunpodInferenceAdapter(OpenAIMixin):
config: RunpodImplConfig
def get_api_key(self) -> str:
"""Get API key for OpenAI client."""
return self.config.api_token
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
SambaNova Inference Adapter for Llama Stack.
"""
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value() if self.config.api_key else ""
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.

View file

@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
)

View file

@ -30,7 +30,7 @@ class _HfAdapter(OpenAIMixin):
overwrite_completion_id = True # TGI always returns id=""
def get_api_key(self):
return self.api_key.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Together AI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:

View file

@ -39,15 +39,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
provider_data_api_key_field: str = "together_api_key"
def get_api_key(self):
return self.config.api_key.get_secret_value() if self.config.api_key else None
def get_base_url(self):
return BASE_URL
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
if config_api_key:
together_api_key = config_api_key
else:

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):
@json_schema_type
class VertexAIConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)

View file

@ -6,7 +6,7 @@
from pathlib import Path
from pydantic import Field, field_validator
from pydantic import Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: str | None = Field(
default="fake",
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)
tls_verify: bool | str = Field(

View file

@ -38,8 +38,10 @@ class VLLMInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "vllm_api_token"
def get_api_key(self) -> str:
return self.config.api_token or ""
def get_api_key(self) -> str | None:
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self) -> str:
"""Get the base URL from config."""

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
# This seems like it should be required, but none of the other remote inference
# providers require it, so this is optional here too for consistency.
# The OpenAIConfig uses default=None instead, so this is following that precedent.
api_key: SecretStr | None = Field(
default=None,
description="The watsonx.ai API key",
)
# As above, this is optional here too for consistency.
project_id: str | None = Field(
default=None,
description="The watsonx.ai project ID",

View file

@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="watsonx",
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
provider_data_api_key_field="watsonx_api_key",
)
self.available_models = None

View file

@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference
class BedrockBaseConfig(RemoteInferenceProviderConfig):
auth_credential: None = Field(default=None, exclude=True)
aws_access_key_id: str | None = Field(
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel):
default=False,
description="Whether to refresh models periodically from the provider",
)
auth_credential: SecretStr | None = Field(
default=None,
description="Authentication credential for the provider",
alias="api_key",
)
# TODO: this class is more confusing than useful right now. We need to make it

View file

@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
The behavior of this class can be customized by child classes in the following ways:
@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@abstractmethod
def get_api_key(self) -> str:
def get_api_key(self) -> str | None:
"""
Get the API key.
This method must be implemented by child classes to provide the API key
for authenticating with the OpenAI API or compatible endpoints.
:return: The API key as a string
:return: The API key as a string, or None if not set
"""
pass
if self.config.auth_credential is None:
return None
return self.config.auth_credential.get_secret_value()
@abstractmethod
def get_base_url(self) -> str:
@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
api_key = getattr(provider_data, self.provider_data_api_key_field)
if not api_key: # TODO: let get_api_key return None
raise ValueError(
"API key is not set. Please provide a valid API key in the "
"provider data header, e.g. x-llamastack-provider-data: "
f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
"or in the provider config."
)
if not api_key:
message = "API key not provided."
if self.provider_data_api_key_field:
message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
raise ValueError(message)
return AsyncOpenAI(
api_key=api_key,