feat: use SecretStr for inference provider auth credentials (#3724)

# What does this PR do?

use SecretStr for OpenAIMixin providers

- `RemoteInferenceProviderConfig` now has `auth_credential: SecretStr`
- the default alias is `api_key` (the most common name)
- some providers override it to `api_token` (RunPod, vLLM, Databricks)
- some providers exclude it entirely (Ollama, TGI, Vertex AI)

addresses #3517 

## Test Plan

ci w/ new tests
Matthew Farrellee, 2025-10-10 10:32:50 -04:00, committed by GitHub
commit 0066d986c5 (parent 6d8f61206e)
57 changed files with 158 additions and 149 deletions

@@ -16,7 +16,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for Anthropic models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 ## Sample Configuration

@@ -23,7 +23,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
 | `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) |

@@ -16,8 +16,8 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
-| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |
 ## Sample Configuration

@@ -16,8 +16,8 @@ Databricks inference provider for running models on Databricks' unified analytic
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The Databricks API token |
 | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
-| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
 ## Sample Configuration

@@ -16,8 +16,8 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
 ## Sample Configuration

@@ -16,7 +16,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for Gemini models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 ## Sample Configuration

@@ -16,7 +16,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | The Groq API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |
 ## Sample Configuration

@@ -16,7 +16,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | The Llama API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
 ## Sample Configuration

@@ -16,8 +16,8 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |

@@ -16,7 +16,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for OpenAI models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
 ## Sample Configuration

@@ -16,8 +16,8 @@ Passthrough inference provider for connecting to any external inference service
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
-| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
 ## Sample Configuration

@@ -16,8 +16,8 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
 | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
-| `api_token` | `str \| None` | No | | The API token |
 ## Sample Configuration

@@ -16,8 +16,8 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |
 ## Sample Configuration

@@ -16,8 +16,8 @@ Together AI inference provider for open-source models and collaborative AI devel
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
 ## Sample Configuration

@@ -16,9 +16,9 @@ Remote vLLM inference provider for connecting to vLLM servers.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
 | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
-| `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 ## Sample Configuration

@@ -16,8 +16,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
 | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

@@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
     # "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
     # }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self):
         return "https://api.anthropic.com/v1"

@@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):
 @json_schema_type
 class AnthropicConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for Anthropic models",
-    )
-
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {

@@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "azure_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
-
     def get_base_url(self) -> str:
         """
         Get the Azure API base URL.

@@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):
 @json_schema_type
 class AzureConfig(RemoteInferenceProviderConfig):
-    api_key: SecretStr = Field(
-        description="Azure API key for Azure",
-    )
     api_base: HttpUrl = Field(
         description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
     )

@@ -15,9 +15,6 @@ from .config import CerebrasImplConfig
 class CerebrasInferenceAdapter(OpenAIMixin):
     config: CerebrasImplConfig

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
-
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")

@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
-        description="Cerebras API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:

@@ -18,8 +18,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
         default=None,
         description="The URL for the Databricks model serving endpoint",
     )
-    api_token: SecretStr = Field(
-        default=SecretStr(None),  # type: ignore[arg-type]
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        alias="api_token",
         description="The Databricks API token",
     )

@@ -27,9 +27,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
         "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_token.get_secret_value()
-
     def get_base_url(self) -> str:
         return f"{self.config.url}/serving-endpoints"

@@ -6,7 +6,7 @@
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The Fireworks.ai API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:

@@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "fireworks_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value() if self.config.api_key else None  # type: ignore[return-value]
-
     def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"

@@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):
 @json_schema_type
 class GeminiConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for Gemini models",
-    )
-
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {

@@ -17,8 +17,5 @@ class GeminiInferenceAdapter(OpenAIMixin):
         "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self):
         return "https://generativelanguage.googleapis.com/v1beta/openai/"

@@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):
 @json_schema_type
 class GroqConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        # The Groq client library loads the GROQ_API_KEY environment variable by default
-        default=None,
-        description="The Groq API key",
-    )
-
     url: str = Field(
         default="https://api.groq.com",
         description="The URL for the Groq AI server",

@@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "groq_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         return f"{self.config.url}/openai/v1"

@@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):
 @json_schema_type
 class LlamaCompatConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="The Llama API key",
-    )
-
     openai_compat_api_base: str = Field(
         default="https://api.llama.com/compat/v1/",
         description="The URL for the Llama API server",

@@ -21,9 +21,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
     Llama API Inference Adapter for Llama Stack.
     """

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         """
         Get the base URL for OpenAI mixin.

@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
         description="A base url for accessing the NVIDIA NIM",
     )
-    api_key: SecretStr | None = Field(
-        default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
-        description="The NVIDIA API key, only needed of using the hosted service",
-    )
     timeout: int = Field(
         default=60,
         description="Timeout for the HTTP requests",

@@ -49,7 +49,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")

         if _is_nvidia_hosted(self.config):
-            if not self.config.api_key:
+            if not self.config.auth_credential:
                 raise RuntimeError(
                     "API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
                 )
@@ -60,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         :return: The NVIDIA API key
         """
-        return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
+        if self.config.auth_credential:
+            return self.config.auth_credential.get_secret_value()
+
+        if not _is_nvidia_hosted(self.config):
+            return "NO KEY REQUIRED"
+
+        return None

     def get_base_url(self) -> str:
         """

@@ -6,12 +6,16 @@
 from typing import Any

+from pydantic import Field, SecretStr
+
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
+
     url: str = DEFAULT_OLLAMA_URL

     @classmethod

@@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
         return self._clients[loop]

     def get_api_key(self):
-        return "NO_KEY"
+        return "NO KEY REQUIRED"

     def get_base_url(self):
         return self.config.url.rstrip("/") + "/v1"

View file

@@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):

 @json_schema_type
 class OpenAIConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for OpenAI models",
-    )
     base_url: str = Field(
         default="https://api.openai.com/v1",
         description="Base URL for OpenAI API",

@@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin):
         "text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         """
         Get the OpenAI API base URL.

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field
+from pydantic import Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
         default=None,
         description="The URL for the Runpod model serving endpoint",
     )
-    api_token: str | None = Field(
+    auth_credential: SecretStr | None = Field(
         default=None,
+        alias="api_token",
         description="The API token",
     )

@@ -24,10 +24,6 @@ class RunpodInferenceAdapter(OpenAIMixin):

     config: RunpodImplConfig

-    def get_api_key(self) -> str:
-        """Get API key for OpenAI client."""
-        return self.config.api_token
-
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return self.config.url

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The SambaNova cloud API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:

@@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
     SambaNova Inference Adapter for Llama Stack.
     """

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value() if self.config.api_key else ""
-
     def get_base_url(self) -> str:
         """
         Get the base URL for OpenAI mixin.

@@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type

 @json_schema_type
 class TGIImplConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )

@@ -30,7 +30,7 @@ class _HfAdapter(OpenAIMixin):
     overwrite_completion_id = True  # TGI always returns id=""

     def get_api_key(self):
-        return self.api_key.get_secret_value()
+        return "NO KEY REQUIRED"

     def get_base_url(self):
         return self.url

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The Together AI API Key",
-    )

     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:

@@ -39,15 +39,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):

     provider_data_api_key_field: str = "together_api_key"

-    def get_api_key(self):
-        return self.config.api_key.get_secret_value() if self.config.api_key else None
-
     def get_base_url(self):
         return BASE_URL

     def _get_client(self) -> AsyncTogether:
         together_api_key = None
-        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
         if config_api_key:
             together_api_key = config_api_key
         else:

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):

 @json_schema_type
 class VertexAIConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
     project: str = Field(
         description="Google Cloud project ID for Vertex AI",
     )

@@ -6,7 +6,7 @@

 from pathlib import Path

-from pydantic import Field, field_validator
+from pydantic import Field, SecretStr, field_validator

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=4096,
         description="Maximum number of tokens to generate.",
     )
-    api_token: str | None = Field(
-        default="fake",
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        alias="api_token",
         description="The API token",
     )
     tls_verify: bool | str = Field(

@@ -38,8 +38,10 @@ class VLLMInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "vllm_api_token"

-    def get_api_key(self) -> str:
-        return self.config.api_token or ""
+    def get_api_key(self) -> str | None:
+        if self.config.auth_credential:
+            return self.config.auth_credential.get_secret_value()
+        return "NO KEY REQUIRED"

     def get_base_url(self) -> str:
         """Get the base URL from config."""

@@ -7,7 +7,7 @@

 import os
 from typing import Any

-from pydantic import BaseModel, ConfigDict, Field, SecretStr
+from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
-    # This seems like it should be required, but none of the other remote inference
-    # providers require it, so this is optional here too for consistency.
-    # The OpenAIConfig uses default=None instead, so this is following that precedent.
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The watsonx.ai API key",
-    )
-    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
         default=None,
         description="The watsonx.ai project ID",

@@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
             provider_data_api_key_field="watsonx_api_key",
         )
         self.available_models = None

@@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference

 class BedrockBaseConfig(RemoteInferenceProviderConfig):
+    auth_credential: None = Field(default=None, exclude=True)
     aws_access_key_id: str | None = Field(
         default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
         description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.models import ModelType
@@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically from the provider",
     )
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        description="Authentication credential for the provider",
+        alias="api_key",
+    )

 # TODO: this class is more confusing than useful right now. We need to make it

@@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     This class handles direct OpenAI API calls using the AsyncOpenAI client.

     This is an abstract base class that requires child classes to implement:
-    - get_api_key(): Method to retrieve the API key
     - get_base_url(): Method to retrieve the OpenAI-compatible API base URL

     The behavior of this class can be customized by child classes in the following ways:
@@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # Optional field name in provider data to look for API key, which takes precedence
     provider_data_api_key_field: str | None = None

-    @abstractmethod
-    def get_api_key(self) -> str:
+    def get_api_key(self) -> str | None:
         """
         Get the API key.

-        This method must be implemented by child classes to provide the API key
-        for authenticating with the OpenAI API or compatible endpoints.
-
-        :return: The API key as a string
+        :return: The API key as a string, or None if not set
         """
-        pass
+        if self.config.auth_credential is None:
+            return None
+        return self.config.auth_credential.get_secret_value()

     @abstractmethod
     def get_base_url(self) -> str:
@@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
             api_key = getattr(provider_data, self.provider_data_api_key_field)

-        if not api_key:  # TODO: let get_api_key return None
-            raise ValueError(
-                "API key is not set. Please provide a valid API key in the "
-                "provider data header, e.g. x-llamastack-provider-data: "
-                f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
-                "or in the provider config."
-            )
+        if not api_key:
+            message = "API key not provided."
+            if self.provider_data_api_key_field:
+                message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
+            raise ValueError(message)

         return AsyncOpenAI(
             api_key=api_key,

@@ -76,6 +76,8 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
     fields_info = {}
     if hasattr(config_class, "model_fields"):
         for field_name, field in config_class.model_fields.items():
+            if getattr(field, "exclude", False):
+                continue
             field_type = str(field.annotation) if field.annotation else "Any"

             # this string replace is ridiculous
@@ -106,7 +108,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
                 "default": default_value,
                 "required": field.default is None and not field.is_required,
             }
-            fields_info[field_name] = field_info
+
+            # Use alias if available, otherwise use the field name
+            display_name = field.alias if field.alias else field_name
+            fields_info[display_name] = field_info

     if accepts_extra_config:
         config_description = "Additional configuration options that will be forwarded to the underlying provider"

@@ -720,7 +720,7 @@ class TestOpenAIMixinProviderDataApiKey:
     ):
         """Test that ValueError is raised when provider data exists but doesn't have required key"""
         with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
-            with pytest.raises(ValueError, match="API key is not set"):
+            with pytest.raises(ValueError, match="API key not provided"):
                 _ = mixin_with_provider_data_field_and_none_api_key.client

     def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
from llama_stack.providers.remote.inference.azure import AzureConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.gemini import GeminiConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.openai import OpenAIConfig
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
class TestRemoteInferenceProviderConfig:
@pytest.mark.parametrize(
"config_cls,alias_name,env_name,extra_config",
[
(AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
(AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
(BedrockConfig, None, None, {}),
(CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
(DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
(FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
(GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
(GroqConfig, "api_key", "GROQ_API_KEY", {}),
(LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
(NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
(OllamaImplConfig, None, None, {}),
(OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
(RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
(SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
(TGIImplConfig, None, None, {"url": "FAKE"}),
(TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
(VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
(VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
(WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
],
)
def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
"""Test that the config class correctly maps the alias to auth_credential."""
secret_value = config_cls.__name__
if alias_name is None:
pytest.skip("No alias name provided for this config class.")
config = config_cls(**{alias_name: secret_value, **extra_config})
assert config.auth_credential is not None
assert config.auth_credential.get_secret_value() == secret_value
schema = config_cls.model_json_schema()
assert alias_name in schema["properties"]
assert "auth_credential" not in schema["properties"]
if env_name:
monkeypatch.setenv(env_name, secret_value)
sample_config = config_cls.sample_run_config()
expanded_config = replace_env_vars(sample_config)
config_from_sample = config_cls(**{**expanded_config, **extra_config})
assert config_from_sample.auth_credential is not None
assert config_from_sample.auth_credential.get_secret_value() == secret_value