From 0066d986c5538ad45e8c9d84dc647ecc764780e0 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 10 Oct 2025 10:32:50 -0400 Subject: [PATCH] feat: use SecretStr for inference provider auth credentials (#3724) # What does this PR do? use SecretStr for OpenAIMixin providers - RemoteInferenceProviderConfig now has auth_credential: SecretStr - the default alias is api_key (most common name) - some providers override to use api_token (RunPod, vLLM, Databricks) - some providers exclude it (Ollama, TGI, Vertex AI) addresses #3517 ## Test Plan ci w/ new tests --- .../providers/inference/remote_anthropic.mdx | 2 +- .../docs/providers/inference/remote_azure.mdx | 2 +- .../providers/inference/remote_cerebras.mdx | 2 +- .../providers/inference/remote_databricks.mdx | 2 +- .../providers/inference/remote_fireworks.mdx | 2 +- .../providers/inference/remote_gemini.mdx | 2 +- docs/docs/providers/inference/remote_groq.mdx | 2 +- .../inference/remote_llama-openai-compat.mdx | 2 +- .../providers/inference/remote_nvidia.mdx | 2 +- .../providers/inference/remote_openai.mdx | 2 +- .../inference/remote_passthrough.mdx | 2 +- .../providers/inference/remote_runpod.mdx | 2 +- .../providers/inference/remote_sambanova.mdx | 2 +- .../providers/inference/remote_together.mdx | 2 +- docs/docs/providers/inference/remote_vllm.mdx | 2 +- .../providers/inference/remote_watsonx.mdx | 2 +- .../remote/inference/anthropic/anthropic.py | 3 - .../remote/inference/anthropic/config.py | 5 -- .../providers/remote/inference/azure/azure.py | 3 - .../remote/inference/azure/config.py | 3 - .../remote/inference/cerebras/cerebras.py | 3 - .../remote/inference/cerebras/config.py | 6 +- .../remote/inference/databricks/config.py | 5 +- .../remote/inference/databricks/databricks.py | 3 - .../remote/inference/fireworks/config.py | 6 +- .../remote/inference/fireworks/fireworks.py | 3 - .../remote/inference/gemini/config.py | 5 -- .../remote/inference/gemini/gemini.py | 3 - .../providers/remote/inference/groq/config.py | 6 -- .../providers/remote/inference/groq/groq.py | 3 - .../inference/llama_openai_compat/config.py | 5 -- .../inference/llama_openai_compat/llama.py | 3 - .../remote/inference/nvidia/config.py | 6 +- .../remote/inference/nvidia/nvidia.py | 10 ++- .../remote/inference/ollama/config.py | 4 + .../remote/inference/ollama/ollama.py | 2 +- .../remote/inference/openai/config.py | 4 - .../remote/inference/openai/openai.py | 3 - .../remote/inference/runpod/config.py | 5 +- .../remote/inference/runpod/runpod.py | 4 - .../remote/inference/sambanova/config.py | 6 +- .../remote/inference/sambanova/sambanova.py | 3 - .../providers/remote/inference/tgi/config.py | 2 + .../providers/remote/inference/tgi/tgi.py | 2 +- .../remote/inference/together/config.py | 6 +- .../remote/inference/together/together.py | 5 +- .../remote/inference/vertexai/config.py | 4 +- .../providers/remote/inference/vllm/config.py | 7 +- .../providers/remote/inference/vllm/vllm.py | 6 +- .../remote/inference/watsonx/config.py | 10 +-- .../remote/inference/watsonx/watsonx.py | 2 +- llama_stack/providers/utils/bedrock/config.py | 1 + .../utils/inference/model_registry.py | 7 +- .../providers/utils/inference/openai_mixin.py | 25 +++--- scripts/provider_codegen.py | 7 +- .../utils/inference/test_openai_mixin.py | 2 +- .../test_remote_inference_provider_config.py | 77 +++++++++++++++++++ 57 files changed, 158 insertions(+), 149 deletions(-) create mode 100644 tests/unit/providers/utils/inference/test_remote_inference_provider_config.py diff --git a/docs/docs/providers/inference/remote_anthropic.mdx b/docs/docs/providers/inference/remote_anthropic.mdx index 44c1fcbb1..4acbbac50 100644 --- a/docs/docs/providers/inference/remote_anthropic.mdx +++ b/docs/docs/providers/inference/remote_anthropic.mdx @@ -16,7 +16,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `api_key` | `str \| None` | No | | API key for Anthropic models | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_azure.mdx b/docs/docs/providers/inference/remote_azure.mdx index 56a14c100..b3041259e 100644 --- a/docs/docs/providers/inference/remote_azure.mdx +++ b/docs/docs/providers/inference/remote_azure.mdx @@ -23,7 +23,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `api_key` | `` | No | | Azure API key for Azure | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `api_base` | `` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) | | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) | | `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) | diff --git a/docs/docs/providers/inference/remote_cerebras.mdx b/docs/docs/providers/inference/remote_cerebras.mdx index d364b9884..cda0be224 100644 --- a/docs/docs/providers/inference/remote_cerebras.mdx +++ b/docs/docs/providers/inference/remote_cerebras.mdx @@ -16,8 +16,8 @@ Cerebras inference provider for running models on Cerebras Cloud platform. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `base_url` | `` | No | https://api.cerebras.ai | Base URL for the Cerebras API | -| `api_key` | `` | No | | Cerebras API Key | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_databricks.mdx b/docs/docs/providers/inference/remote_databricks.mdx index d7b0bd38d..f14fd0175 100644 --- a/docs/docs/providers/inference/remote_databricks.mdx +++ b/docs/docs/providers/inference/remote_databricks.mdx @@ -16,8 +16,8 @@ Databricks inference provider for running models on Databricks' unified analytic |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_token` | `pydantic.types.SecretStr \| None` | No | | The Databricks API token | | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint | -| `api_token` | `` | No | | The Databricks API token | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_fireworks.mdx b/docs/docs/providers/inference/remote_fireworks.mdx index cfdfb993c..71f16ccec 100644 --- a/docs/docs/providers/inference/remote_fireworks.mdx +++ b/docs/docs/providers/inference/remote_fireworks.mdx @@ -16,8 +16,8 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_gemini.mdx b/docs/docs/providers/inference/remote_gemini.mdx index a13d1c82d..22b3c8cb7 100644 --- a/docs/docs/providers/inference/remote_gemini.mdx +++ b/docs/docs/providers/inference/remote_gemini.mdx @@ -16,7 +16,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `api_key` | `str \| None` | No | | API key for Gemini models | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_groq.mdx b/docs/docs/providers/inference/remote_groq.mdx index 1edb4f9ea..aaf1516ca 100644 --- a/docs/docs/providers/inference/remote_groq.mdx +++ b/docs/docs/providers/inference/remote_groq.mdx @@ -16,7 +16,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `api_key` | `str \| None` | No | | The Groq API key | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `url` | `` | No | https://api.groq.com | The URL for the Groq AI server | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_llama-openai-compat.mdx b/docs/docs/providers/inference/remote_llama-openai-compat.mdx index ca5830b09..9769c0793 100644 --- a/docs/docs/providers/inference/remote_llama-openai-compat.mdx +++ b/docs/docs/providers/inference/remote_llama-openai-compat.mdx @@ -16,7 +16,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `api_key` | `str \| None` | No | | The Llama API key | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `openai_compat_api_base` | `` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx index 6b5e36180..b4e04176c 100644 --- a/docs/docs/providers/inference/remote_nvidia.mdx +++ b/docs/docs/providers/inference/remote_nvidia.mdx @@ -16,8 +16,8 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | | `append_api_version` | `` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. | diff --git a/docs/docs/providers/inference/remote_openai.mdx b/docs/docs/providers/inference/remote_openai.mdx index e0910c809..28c8ab7bf 100644 --- a/docs/docs/providers/inference/remote_openai.mdx +++ b/docs/docs/providers/inference/remote_openai.mdx @@ -16,7 +16,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `api_key` | `str \| None` | No | | API key for OpenAI models | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `base_url` | `` | No | https://api.openai.com/v1 | Base URL for OpenAI API | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_passthrough.mdx b/docs/docs/providers/inference/remote_passthrough.mdx index e356384ad..7a2931690 100644 --- a/docs/docs/providers/inference/remote_passthrough.mdx +++ b/docs/docs/providers/inference/remote_passthrough.mdx @@ -16,8 +16,8 @@ Passthrough inference provider for connecting to any external inference service |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | -| `url` | `` | No | | The URL for the passthrough endpoint | | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint | +| `url` | `` | No | | The URL for the passthrough endpoint | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_runpod.mdx b/docs/docs/providers/inference/remote_runpod.mdx index 876532029..3cbbd0322 100644 --- a/docs/docs/providers/inference/remote_runpod.mdx +++ b/docs/docs/providers/inference/remote_runpod.mdx @@ -16,8 +16,8 @@ RunPod inference provider for running models on RunPod's cloud GPU platform. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token | | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint | -| `api_token` | `str \| None` | No | | The API token | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_sambanova.mdx b/docs/docs/providers/inference/remote_sambanova.mdx index 9bd7b7613..0ac4600b7 100644 --- a/docs/docs/providers/inference/remote_sambanova.mdx +++ b/docs/docs/providers/inference/remote_sambanova.mdx @@ -16,8 +16,8 @@ SambaNova inference provider for running models on SambaNova's dataflow architec |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `url` | `` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_together.mdx b/docs/docs/providers/inference/remote_together.mdx index 6df2ca866..c8e3bcdcf 100644 --- a/docs/docs/providers/inference/remote_together.mdx +++ b/docs/docs/providers/inference/remote_together.mdx @@ -16,8 +16,8 @@ Together AI inference provider for open-source models and collaborative AI devel |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx index fbbd424a3..f844bcee0 100644 --- a/docs/docs/providers/inference/remote_vllm.mdx +++ b/docs/docs/providers/inference/remote_vllm.mdx @@ -16,9 +16,9 @@ Remote vLLM inference provider for connecting to vLLM servers. |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token | | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint | | `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. | -| `api_token` | `str \| None` | No | fake | The API token | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | ## Sample Configuration diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx index f081703ab..2227aa1cc 100644 --- a/docs/docs/providers/inference/remote_watsonx.mdx +++ b/docs/docs/providers/inference/remote_watsonx.mdx @@ -16,8 +16,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform |-------|------|----------|---------|-------------| | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider | | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key | | `project_id` | `str \| None` | No | | The watsonx.ai project ID | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | diff --git a/llama_stack/providers/remote/inference/anthropic/anthropic.py b/llama_stack/providers/remote/inference/anthropic/anthropic.py index 3b996b16e..dc9d8fb40 100644 --- a/llama_stack/providers/remote/inference/anthropic/anthropic.py +++ b/llama_stack/providers/remote/inference/anthropic/anthropic.py @@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin): # "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000}, # } - def get_api_key(self) -> str: - return self.config.api_key or "" - def get_base_url(self): return "https://api.anthropic.com/v1" diff --git a/llama_stack/providers/remote/inference/anthropic/config.py b/llama_stack/providers/remote/inference/anthropic/config.py index de523ca5a..31e6aa12b 100644 --- a/llama_stack/providers/remote/inference/anthropic/config.py +++ b/llama_stack/providers/remote/inference/anthropic/config.py @@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel): @json_schema_type class AnthropicConfig(RemoteInferenceProviderConfig): - api_key: str | None = Field( - default=None, - description="API key for Anthropic models", - ) - @classmethod def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]: return { diff --git a/llama_stack/providers/remote/inference/azure/azure.py b/llama_stack/providers/remote/inference/azure/azure.py index 0c8f6e7ad..134d01b15 100644 --- a/llama_stack/providers/remote/inference/azure/azure.py +++ b/llama_stack/providers/remote/inference/azure/azure.py @@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin): provider_data_api_key_field: str = "azure_api_key" - def get_api_key(self) -> str: - return self.config.api_key.get_secret_value() - def get_base_url(self) -> str: """ Get the Azure API base URL. diff --git a/llama_stack/providers/remote/inference/azure/config.py b/llama_stack/providers/remote/inference/azure/config.py index 8bc7335a3..7c31df7a6 100644 --- a/llama_stack/providers/remote/inference/azure/config.py +++ b/llama_stack/providers/remote/inference/azure/config.py @@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel): @json_schema_type class AzureConfig(RemoteInferenceProviderConfig): - api_key: SecretStr = Field( - description="Azure API key for Azure", - ) api_base: HttpUrl = Field( description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)", ) diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 11ef218a1..0e24af0ee 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -15,9 +15,6 @@ from .config import CerebrasImplConfig class CerebrasInferenceAdapter(OpenAIMixin): config: CerebrasImplConfig - def get_api_key(self) -> str: - return self.config.api_key.get_secret_value() - def get_base_url(self) -> str: return urljoin(self.config.base_url, "v1") diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py index 40db38935..dc9a0f5fc 100644 --- a/llama_stack/providers/remote/inference/cerebras/config.py +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -7,7 +7,7 @@ import os from typing import Any -from pydantic import Field, SecretStr +from pydantic import Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig): default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), description="Base URL for the Cerebras API", ) - api_key: SecretStr = Field( - default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type] - description="Cerebras API Key", - ) @classmethod def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py index 68e94151e..49d19cd35 100644 --- a/llama_stack/providers/remote/inference/databricks/config.py +++ b/llama_stack/providers/remote/inference/databricks/config.py @@ -18,8 +18,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig): default=None, description="The URL for the Databricks model serving endpoint", ) - api_token: SecretStr = Field( - default=SecretStr(None), # type: ignore[arg-type] + auth_credential: SecretStr | None = Field( + default=None, + alias="api_token", description="The Databricks API token", ) diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 200b36171..705f4bddd 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -27,9 +27,6 @@ class DatabricksInferenceAdapter(OpenAIMixin): "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512}, } - def get_api_key(self) -> str: - return self.config.api_token.get_secret_value() - def get_base_url(self) -> str: return f"{self.config.url}/serving-endpoints" diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index cd28096a5..20ba99606 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import Field, SecretStr +from pydantic import Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig): default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", ) - api_key: SecretStr | None = Field( - default=None, - description="The Fireworks.ai API Key", - ) @classmethod def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 81dbff0a3..7e2b73546 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin): provider_data_api_key_field: str = "fireworks_api_key" - def get_api_key(self) -> str: - return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value] - def get_base_url(self) -> str: return "https://api.fireworks.ai/inference/v1" diff --git a/llama_stack/providers/remote/inference/gemini/config.py b/llama_stack/providers/remote/inference/gemini/config.py index c7dacec96..df5da29a2 100644 --- a/llama_stack/providers/remote/inference/gemini/config.py +++ b/llama_stack/providers/remote/inference/gemini/config.py @@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel): @json_schema_type class GeminiConfig(RemoteInferenceProviderConfig): - api_key: str | None = Field( - default=None, - description="API key for Gemini models", - ) - @classmethod def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]: return { diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index ea7219a59..bb66b94d5 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -17,8 +17,5 @@ class GeminiInferenceAdapter(OpenAIMixin): "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048}, } - def get_api_key(self) -> str: - return self.config.api_key or "" - def get_base_url(self): return "https://generativelanguage.googleapis.com/v1beta/openai/" diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py index 23deba22e..c1aedca3e 100644 --- a/llama_stack/providers/remote/inference/groq/config.py +++ b/llama_stack/providers/remote/inference/groq/config.py @@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel): @json_schema_type class GroqConfig(RemoteInferenceProviderConfig): - api_key: str | None = Field( - # The Groq client library loads the GROQ_API_KEY environment variable by default - default=None, - description="The Groq API key", - ) - url: str = Field( default="https://api.groq.com", description="The URL for the Groq AI server", diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 21b37de36..3a4f2626d 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin): provider_data_api_key_field: str = "groq_api_key" - def get_api_key(self) -> str: - return self.config.api_key or "" - def get_base_url(self) -> str: return f"{self.config.url}/openai/v1" diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/config.py b/llama_stack/providers/remote/inference/llama_openai_compat/config.py index 0697c041d..4b5750ed4 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/config.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/config.py @@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel): @json_schema_type class LlamaCompatConfig(RemoteInferenceProviderConfig): - api_key: str | None = Field( - default=None, - description="The Llama API key", - ) - openai_compat_api_base: str = Field( default="https://api.llama.com/compat/v1/", description="The URL for the Llama API server", diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 165992c16..6995665f7 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -21,9 +21,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin): Llama API Inference Adapter for Llama Stack. """ - def get_api_key(self) -> str: - return self.config.api_key or "" - def get_base_url(self) -> str: """ Get the base URL for OpenAI mixin. diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index 4b310d770..2171877a5 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -7,7 +7,7 @@ import os from typing import Any -from pydantic import Field, SecretStr +from pydantic import Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig): default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"), description="A base url for accessing the NVIDIA NIM", ) - api_key: SecretStr | None = Field( - default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")), - description="The NVIDIA API key, only needed of using the hosted service", - ) timeout: int = Field( default=60, description="Timeout for the HTTP requests", diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7a2697327..9d8d1089a 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -49,7 +49,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin): logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...") if _is_nvidia_hosted(self.config): - if not self.config.api_key: + if not self.config.auth_credential: raise RuntimeError( "API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM." ) @@ -60,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin): :return: The NVIDIA API key """ - return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY" + if self.config.auth_credential: + return self.config.auth_credential.get_secret_value() + + if not _is_nvidia_hosted(self.config): + return "NO KEY REQUIRED" + + return None def get_base_url(self) -> str: """ diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index 1e4ce9113..416b847a0 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -6,12 +6,16 @@ from typing import Any +from pydantic import Field, SecretStr + from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig DEFAULT_OLLAMA_URL = "http://localhost:11434" class OllamaImplConfig(RemoteInferenceProviderConfig): + auth_credential: SecretStr | None = Field(default=None, exclude=True) + url: str = DEFAULT_OLLAMA_URL @classmethod diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 67d0caa54..50f36d045 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin): return self._clients[loop] def get_api_key(self): - return "NO_KEY" + return "NO KEY REQUIRED" def get_base_url(self): return self.config.url.rstrip("/") + "/v1" diff --git a/llama_stack/providers/remote/inference/openai/config.py b/llama_stack/providers/remote/inference/openai/config.py index e494e967b..36c66bd28 100644 --- a/llama_stack/providers/remote/inference/openai/config.py +++ b/llama_stack/providers/remote/inference/openai/config.py @@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel): @json_schema_type class OpenAIConfig(RemoteInferenceProviderConfig): - api_key: str | None = Field( - default=None, - description="API key for OpenAI models", - ) base_url: str = Field( default="https://api.openai.com/v1", description="Base URL for OpenAI API", diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index f68e8f9d6..52bc48f1a 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin): "text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192}, } - def get_api_key(self) -> str: - return self.config.api_key or "" - def get_base_url(self) -> str: """ Get the OpenAI API base URL. diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py index cdfe0f885..3d16d20fd 100644 --- a/llama_stack/providers/remote/inference/runpod/config.py +++ b/llama_stack/providers/remote/inference/runpod/config.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import Field +from pydantic import Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig): default=None, description="The URL for the Runpod model serving endpoint", ) - api_token: str | None = Field( + auth_credential: SecretStr | None = Field( default=None, + alias="api_token", description="The API token", ) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index f752740e5..67e430ac5 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -24,10 +24,6 @@ class RunpodInferenceAdapter(OpenAIMixin): config: RunpodImplConfig - def get_api_key(self) -> str: - """Get API key for OpenAI client.""" - return self.config.api_token - def get_base_url(self) -> str: """Get base URL for OpenAI client.""" return self.config.url diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py index a614663dc..f63210434 100644 --- a/llama_stack/providers/remote/inference/sambanova/config.py +++ b/llama_stack/providers/remote/inference/sambanova/config.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig): default="https://api.sambanova.ai/v1", description="The URL for the SambaNova AI server", ) - api_key: SecretStr | None = Field( - default=None, - description="The SambaNova cloud API Key", - ) @classmethod def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index f30bab780..daa4b1670 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin): SambaNova Inference Adapter for Llama Stack. """ - def get_api_key(self) -> str: - return self.config.api_key.get_secret_value() if self.config.api_key else "" - def get_base_url(self) -> str: """ Get the base URL for OpenAI mixin. diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py index d3110b2af..47952abba 100644 --- a/llama_stack/providers/remote/inference/tgi/config.py +++ b/llama_stack/providers/remote/inference/tgi/config.py @@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class TGIImplConfig(RemoteInferenceProviderConfig): + auth_credential: SecretStr | None = Field(default=None, exclude=True) + url: str = Field( description="The URL for the TGI serving endpoint", ) diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a316e8996..da3205a13 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -30,7 +30,7 @@ class _HfAdapter(OpenAIMixin): overwrite_completion_id = True # TGI always returns id="" def get_api_key(self): - return self.api_key.get_secret_value() + return "NO KEY REQUIRED" def get_base_url(self): return self.url diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index f6725333c..47392c8e7 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import Field, SecretStr +from pydantic import Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig): default="https://api.together.xyz/v1", description="The URL for the Together AI server", ) - api_key: SecretStr | None = Field( - default=None, - description="The Together AI API Key", - ) @classmethod def sample_run_config(cls, **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 224de6721..e29cccf04 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -39,15 +39,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData): provider_data_api_key_field: str = "together_api_key" - def get_api_key(self): - return self.config.api_key.get_secret_value() if self.config.api_key else None - def get_base_url(self): return BASE_URL def _get_client(self) -> AsyncTogether: together_api_key = None - config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None + config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None if config_api_key: together_api_key = config_api_key else: diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py index 97d0852a8..5f2efa894 100644 --- a/llama_stack/providers/remote/inference/vertexai/config.py +++ b/llama_stack/providers/remote/inference/vertexai/config.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel): @json_schema_type class VertexAIConfig(RemoteInferenceProviderConfig): + auth_credential: SecretStr | None = Field(default=None, exclude=True) + project: str = Field( description="Google Cloud project ID for Vertex AI", ) diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 87c5408d3..e362aece6 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -6,7 +6,7 @@ from pathlib import Path -from pydantic import Field, field_validator +from pydantic import Field, SecretStr, field_validator from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): default=4096, description="Maximum number of tokens to generate.", ) - api_token: str | None = Field( - default="fake", + auth_credential: SecretStr | None = Field( + default=None, + alias="api_token", description="The API token", ) tls_verify: bool | str = Field( diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 5974ca176..9e5f17c73 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -38,8 +38,10 @@ class VLLMInferenceAdapter(OpenAIMixin): provider_data_api_key_field: str = "vllm_api_token" - def get_api_key(self) -> str: - return self.config.api_token or "" + def get_api_key(self) -> str | None: + if self.config.auth_credential: + return self.config.auth_credential.get_secret_value() + return "NO KEY REQUIRED" def get_base_url(self) -> str: """Get the base URL from config.""" diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 9e98d4003..022dc5ee7 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -7,7 +7,7 @@ import os from typing import Any -from pydantic import BaseModel, ConfigDict, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig): default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), description="A base url for accessing the watsonx.ai", ) - # This seems like it should be required, but none of the other remote inference - # providers require it, so this is optional here too for consistency. - # The OpenAIConfig uses default=None instead, so this is following that precedent. - api_key: SecretStr | None = Field( - default=None, - description="The watsonx.ai API key", - ) - # As above, this is optional here too for consistency. project_id: str | None = Field( default=None, description="The watsonx.ai project ID", diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index d04472936..654d61f34 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): LiteLLMOpenAIMixin.__init__( self, litellm_provider_name="watsonx", - api_key_from_config=config.api_key.get_secret_value() if config.api_key else None, + api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None, provider_data_api_key_field="watsonx_api_key", ) self.available_models = None diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py index 418cf381b..7bddec348 100644 --- a/llama_stack/providers/utils/bedrock/config.py +++ b/llama_stack/providers/utils/bedrock/config.py @@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInference class BedrockBaseConfig(RemoteInferenceProviderConfig): + auth_credential: None = Field(default=None, exclude=True) aws_access_key_id: str | None = Field( default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"), description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 9d42d68c6..d60d00f87 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -6,7 +6,7 @@ from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.models import ModelType @@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel): default=False, description="Whether to refresh models periodically from the provider", ) + auth_credential: SecretStr | None = Field( + default=None, + description="Authentication credential for the provider", + alias="api_key", + ) # TODO: this class is more confusing than useful right now. We need to make it diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index cba7508a2..33a8b81b5 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): This class handles direct OpenAI API calls using the AsyncOpenAI client. This is an abstract base class that requires child classes to implement: - - get_api_key(): Method to retrieve the API key - get_base_url(): Method to retrieve the OpenAI-compatible API base URL The behavior of this class can be customized by child classes in the following ways: @@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # Optional field name in provider data to look for API key, which takes precedence provider_data_api_key_field: str | None = None - @abstractmethod - def get_api_key(self) -> str: + def get_api_key(self) -> str | None: """ Get the API key. - This method must be implemented by child classes to provide the API key - for authenticating with the OpenAI API or compatible endpoints. - - :return: The API key as a string + :return: The API key as a string, or None if not set """ - pass + if self.config.auth_credential is None: + return None + return self.config.auth_credential.get_secret_value() @abstractmethod def get_base_url(self) -> str: @@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): if provider_data and getattr(provider_data, self.provider_data_api_key_field, None): api_key = getattr(provider_data, self.provider_data_api_key_field) - if not api_key: # TODO: let get_api_key return None - raise ValueError( - "API key is not set. Please provide a valid API key in the " - "provider data header, e.g. x-llamastack-provider-data: " - f'{{"{self.provider_data_api_key_field}": ""}}, ' - "or in the provider config." - ) + if not api_key: + message = "API key not provided." + if self.provider_data_api_key_field: + message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": ""}}.' + raise ValueError(message) return AsyncOpenAI( api_key=api_key, diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 34e4c0687..de79b4d17 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -76,6 +76,8 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: fields_info = {} if hasattr(config_class, "model_fields"): for field_name, field in config_class.model_fields.items(): + if getattr(field, "exclude", False): + continue field_type = str(field.annotation) if field.annotation else "Any" # this string replace is ridiculous @@ -106,7 +108,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: "default": default_value, "required": field.default is None and not field.is_required, } - fields_info[field_name] = field_info + + # Use alias if available, otherwise use the field name + display_name = field.alias if field.alias else field_name + fields_info[display_name] = field_info if accepts_extra_config: config_description = "Additional configuration options that will be forwarded to the underlying provider" diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index ad9406951..8ce4925e1 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -720,7 +720,7 @@ class TestOpenAIMixinProviderDataApiKey: ): """Test that ValueError is raised when provider data exists but doesn't have required key""" with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): - with pytest.raises(ValueError, match="API key is not set"): + with pytest.raises(ValueError, match="API key not provided"): _ = mixin_with_provider_data_field_and_none_api_key.client def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key): diff --git a/tests/unit/providers/utils/inference/test_remote_inference_provider_config.py b/tests/unit/providers/utils/inference/test_remote_inference_provider_config.py new file mode 100644 index 000000000..76c49900c --- /dev/null +++ b/tests/unit/providers/utils/inference/test_remote_inference_provider_config.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.core.stack import replace_env_vars +from llama_stack.providers.remote.inference.anthropic import AnthropicConfig +from llama_stack.providers.remote.inference.azure import AzureConfig +from llama_stack.providers.remote.inference.bedrock import BedrockConfig +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig +from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig +from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig +from llama_stack.providers.remote.inference.gemini import GeminiConfig +from llama_stack.providers.remote.inference.groq import GroqConfig +from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig +from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig +from llama_stack.providers.remote.inference.openai import OpenAIConfig +from llama_stack.providers.remote.inference.runpod import RunpodImplConfig +from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig +from llama_stack.providers.remote.inference.tgi import TGIImplConfig +from llama_stack.providers.remote.inference.together import TogetherImplConfig +from llama_stack.providers.remote.inference.vertexai import VertexAIConfig +from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig +from llama_stack.providers.remote.inference.watsonx import WatsonXConfig + + +class TestRemoteInferenceProviderConfig: + @pytest.mark.parametrize( + "config_cls,alias_name,env_name,extra_config", + [ + (AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}), + (AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}), + (BedrockConfig, None, None, {}), + (CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}), + (DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}), + (FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}), + (GeminiConfig, "api_key", "GEMINI_API_KEY", {}), + (GroqConfig, "api_key", "GROQ_API_KEY", {}), + (LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}), + (NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}), + (OllamaImplConfig, None, None, {}), + (OpenAIConfig, "api_key", "OPENAI_API_KEY", {}), + (RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}), + (SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}), + (TGIImplConfig, None, None, {"url": "FAKE"}), + (TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}), + (VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}), + (VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}), + (WatsonXConfig, "api_key", "WATSONX_API_KEY", {}), + ], + ) + def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config): + """Test that the config class correctly maps the alias to auth_credential.""" + secret_value = config_cls.__name__ + + if alias_name is None: + pytest.skip("No alias name provided for this config class.") + + config = config_cls(**{alias_name: secret_value, **extra_config}) + assert config.auth_credential is not None + assert config.auth_credential.get_secret_value() == secret_value + + schema = config_cls.model_json_schema() + assert alias_name in schema["properties"] + assert "auth_credential" not in schema["properties"] + + if env_name: + monkeypatch.setenv(env_name, secret_value) + sample_config = config_cls.sample_run_config() + expanded_config = replace_env_vars(sample_config) + config_from_sample = config_cls(**{**expanded_config, **extra_config}) + assert config_from_sample.auth_credential is not None + assert config_from_sample.auth_credential.get_secret_value() == secret_value