Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-11 13:44:38 +00:00)
feat: use SecretStr for inference provider auth credentials (#3724)
# What does this PR do?

Use SecretStr for OpenAIMixin providers:

- `RemoteInferenceProviderConfig` now has `auth_credential: SecretStr`
- the default alias is `api_key` (the most common name)
- some providers override the alias to `api_token` (RunPod, vLLM, Databricks)
- some providers exclude the credential entirely (Ollama, TGI, Vertex AI)

Addresses #3517

## Test Plan

CI, with new tests.
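The resulting pattern, in miniature. This is a hedged sketch, not the verbatim repo code (the real field lives on `RemoteInferenceProviderConfig`, shown in the diff below); it assumes Pydantic v2 alias semantics:

```python
from pydantic import BaseModel, Field, SecretStr


class ExampleProviderConfig(BaseModel):
    # Shared credential field; the alias keeps existing "api_key" configs working.
    auth_credential: SecretStr | None = Field(
        default=None,
        alias="api_key",
        description="Authentication credential for the provider",
    )


cfg = ExampleProviderConfig(api_key="sk-test")  # validated under the alias
print(cfg.auth_credential)                      # **********  (SecretStr masks str/repr)
print(cfg.auth_credential.get_secret_value())   # sk-test
```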
This commit is contained in: parent 6d8f61206e, commit 0066d986c5
57 changed files with 158 additions and 149 deletions
@@ -16,7 +16,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI services.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for Anthropic models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |

 ## Sample Configuration

@@ -23,7 +23,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
 | `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) |

@@ -16,8 +16,8 @@ Cerebras inference provider for running models on Cerebras Cloud platform.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
-| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |

 ## Sample Configuration

@@ -16,8 +16,8 @@ Databricks inference provider for running models on Databricks' unified analytics platform.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The Databricks API token |
 | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
-| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |

 ## Sample Configuration

@@ -16,8 +16,8 @@ Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |

 ## Sample Configuration

@@ -16,7 +16,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI services.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for Gemini models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |

 ## Sample Configuration

@@ -16,7 +16,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | The Groq API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |

 ## Sample Configuration

@@ -16,7 +16,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | The Llama API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

 ## Sample Configuration

@@ -16,8 +16,8 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed if using the hosted service |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |

@@ -16,7 +16,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for OpenAI models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

 ## Sample Configuration

@@ -16,8 +16,8 @@ Passthrough inference provider for connecting to any external inference service.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrough endpoint |
+| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |

 ## Sample Configuration

@@ -16,8 +16,8 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
 | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
-| `api_token` | `str \| None` | No | | The API token |

 ## Sample Configuration

@@ -16,8 +16,8 @@ SambaNova inference provider for running models on SambaNova's dataflow architecture.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |

 ## Sample Configuration

@@ -16,8 +16,8 @@ Together AI inference provider for open-source models and collaborative AI development.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |

 ## Sample Configuration

@@ -16,9 +16,9 @@ Remote vLLM inference provider for connecting to vLLM servers.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
 | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
-| `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |

 ## Sample Configuration

@@ -16,8 +16,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
 | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

@@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
     # "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
     # }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self):
         return "https://api.anthropic.com/v1"

@@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):

 @json_schema_type
 class AnthropicConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for Anthropic models",
-    )
-
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {

@@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "azure_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
-
     def get_base_url(self) -> str:
         """
         Get the Azure API base URL.

@@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):

 @json_schema_type
 class AzureConfig(RemoteInferenceProviderConfig):
-    api_key: SecretStr = Field(
-        description="Azure API key for Azure",
-    )
     api_base: HttpUrl = Field(
         description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
     )

@@ -15,9 +15,6 @@ from .config import CerebrasImplConfig
 class CerebrasInferenceAdapter(OpenAIMixin):
     config: CerebrasImplConfig

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
-
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")

@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type

@@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
-        description="Cerebras API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:

@@ -18,8 +18,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
         default=None,
         description="The URL for the Databricks model serving endpoint",
     )
-    api_token: SecretStr = Field(
-        default=SecretStr(None),  # type: ignore[arg-type]
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        alias="api_token",
         description="The Databricks API token",
     )

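The alias override above, as a self-contained sketch (hypothetical class name; the real config subclasses `RemoteInferenceProviderConfig`): the field is always `auth_credential` in code, while run configs keep the provider's customary `api_token` spelling.

```python
from pydantic import BaseModel, Field, SecretStr


class ExampleTokenConfig(BaseModel):
    # Same field name everywhere, provider-specific spelling in configs.
    auth_credential: SecretStr | None = Field(
        default=None,
        alias="api_token",
        description="The Databricks API token",
    )


cfg = ExampleTokenConfig(api_token="dapi-123")  # validated under the alias
assert cfg.auth_credential.get_secret_value() == "dapi-123"
```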
@@ -27,9 +27,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
         "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_token.get_secret_value()
-
     def get_base_url(self) -> str:
         return f"{self.config.url}/serving-endpoints"

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type

@@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The Fireworks.ai API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:

@@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "fireworks_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value() if self.config.api_key else None  # type: ignore[return-value]
-
     def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"

@@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):

 @json_schema_type
 class GeminiConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for Gemini models",
-    )
-
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {

@@ -17,8 +17,5 @@ class GeminiInferenceAdapter(OpenAIMixin):
         "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self):
         return "https://generativelanguage.googleapis.com/v1beta/openai/"

@@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):

 @json_schema_type
 class GroqConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        # The Groq client library loads the GROQ_API_KEY environment variable by default
-        default=None,
-        description="The Groq API key",
-    )
-
     url: str = Field(
         default="https://api.groq.com",
         description="The URL for the Groq AI server",

@@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "groq_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         return f"{self.config.url}/openai/v1"

@@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):

 @json_schema_type
 class LlamaCompatConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="The Llama API key",
-    )
-
     openai_compat_api_base: str = Field(
         default="https://api.llama.com/compat/v1/",
         description="The URL for the Llama API server",

@@ -21,9 +21,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
     Llama API Inference Adapter for Llama Stack.
     """

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         """
         Get the base URL for OpenAI mixin.

@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type

@@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
         description="A base url for accessing the NVIDIA NIM",
     )
-    api_key: SecretStr | None = Field(
-        default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
-        description="The NVIDIA API key, only needed if using the hosted service",
-    )
     timeout: int = Field(
         default=60,
         description="Timeout for the HTTP requests",

@@ -49,7 +49,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")

         if _is_nvidia_hosted(self.config):
-            if not self.config.api_key:
+            if not self.config.auth_credential:
                 raise RuntimeError(
                     "API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
                 )

@@ -60,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):

         :return: The NVIDIA API key
         """
-        return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
+        if self.config.auth_credential:
+            return self.config.auth_credential.get_secret_value()
+
+        if not _is_nvidia_hosted(self.config):
+            return "NO KEY REQUIRED"
+
+        return None

     def get_base_url(self) -> str:
         """

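The NVIDIA key resolution above, restated as a standalone function (a hypothetical helper mirroring the diff's logic, not repo code): a configured credential always wins, a self-hosted NIM needs no key, and a hosted NIM with no key yields None so the client-construction path can fail with a clear error.

```python
def resolve_api_key(credential: str | None, hosted: bool) -> str | None:
    if credential:
        return credential       # explicit credential takes precedence
    if not hosted:
        return "NO KEY REQUIRED"  # self-hosted NIM: placeholder is fine
    return None                 # hosted endpoint with no credential: caller raises
```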
@@ -6,12 +6,16 @@

 from typing import Any

+from pydantic import Field, SecretStr
+
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
+
     url: str = DEFAULT_OLLAMA_URL

     @classmethod

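The opt-out pattern used by Ollama here (and by TGI and Vertex AI below), sketched standalone. This assumes Pydantic v2 `exclude` semantics: the field is pinned to None and never serializes, and the docs generator change later in this PR skips excluded fields entirely.

```python
from pydantic import BaseModel, Field, SecretStr


class NoAuthConfig(BaseModel):
    # No credential for this provider: excluded from dumps and generated docs.
    auth_credential: SecretStr | None = Field(default=None, exclude=True)


assert NoAuthConfig().model_dump() == {}  # excluded fields are omitted
```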
@@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
         return self._clients[loop]

     def get_api_key(self):
-        return "NO_KEY"
+        return "NO KEY REQUIRED"

     def get_base_url(self):
         return self.config.url.rstrip("/") + "/v1"

@@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):

 @json_schema_type
 class OpenAIConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for OpenAI models",
-    )
     base_url: str = Field(
         default="https://api.openai.com/v1",
         description="Base URL for OpenAI API",

|
|||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key or ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the OpenAI API base URL.
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
|
|||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
|
|
|
@ -24,10 +24,6 @@ class RunpodInferenceAdapter(OpenAIMixin):
|
|||
|
||||
config: RunpodImplConfig
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""Get API key for OpenAI client."""
|
||||
return self.config.api_token
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
|
|||
SambaNova Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
|
|
@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
|
|
@ -30,7 +30,7 @@ class _HfAdapter(OpenAIMixin):
|
|||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
|||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -39,15 +39,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
|||
|
||||
provider_data_api_key_field: str = "together_api_key"
|
||||
|
||||
def get_api_key(self):
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class VertexAIConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
project: str = Field(
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
|||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
api_token: str | None = Field(
|
||||
default="fake",
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool | str = Field(
|
||||
|
|
|
@@ -38,8 +38,10 @@ class VLLMInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "vllm_api_token"

-    def get_api_key(self) -> str:
-        return self.config.api_token or ""
+    def get_api_key(self) -> str | None:
+        if self.config.auth_credential:
+            return self.config.auth_credential.get_secret_value()
+        return "NO KEY REQUIRED"

     def get_base_url(self) -> str:
         """Get the base URL from config."""

|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
# This seems like it should be required, but none of the other remote inference
|
||||
# providers require it, so this is optional here too for consistency.
|
||||
# The OpenAIConfig uses default=None instead, so this is following that precedent.
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="The watsonx.ai API key",
|
||||
)
|
||||
# As above, this is optional here too for consistency.
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="The watsonx.ai project ID",
|
||||
|
|
|
@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||
api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
)
|
||||
self.available_models = None
|
||||
|
|
@@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig


 class BedrockBaseConfig(RemoteInferenceProviderConfig):
+    auth_credential: None = Field(default=None, exclude=True)
     aws_access_key_id: str | None = Field(
         default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
         description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",

@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import ModelType
|
||||
|
@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to refresh models periodically from the provider",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Authentication credential for the provider",
|
||||
alias="api_key",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
|
|
|
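Why SecretStr for this field: under Pydantic v2 (an assumption of this sketch), the secret is masked anywhere the config is printed, logged, or dumped, and only an explicit accessor reveals it.

```python
from pydantic import SecretStr

s = SecretStr("sk-live-abc123")
print(s)                     # **********
print(repr(s))               # SecretStr('**********')
print(s.get_secret_value())  # sk-live-abc123
```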
@@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     This class handles direct OpenAI API calls using the AsyncOpenAI client.

     This is an abstract base class that requires child classes to implement:
-    - get_api_key(): Method to retrieve the API key
     - get_base_url(): Method to retrieve the OpenAI-compatible API base URL

     The behavior of this class can be customized by child classes in the following ways:

@@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # Optional field name in provider data to look for API key, which takes precedence
     provider_data_api_key_field: str | None = None

-    @abstractmethod
-    def get_api_key(self) -> str:
+    def get_api_key(self) -> str | None:
         """
         Get the API key.

-        This method must be implemented by child classes to provide the API key
-        for authenticating with the OpenAI API or compatible endpoints.
-
-        :return: The API key as a string
+        :return: The API key as a string, or None if not set
         """
-        pass
+        if self.config.auth_credential is None:
+            return None
+        return self.config.auth_credential.get_secret_value()

     @abstractmethod
     def get_base_url(self) -> str:

@@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
             api_key = getattr(provider_data, self.provider_data_api_key_field)

-        if not api_key:  # TODO: let get_api_key return None
-            raise ValueError(
-                "API key is not set. Please provide a valid API key in the "
-                "provider data header, e.g. x-llamastack-provider-data: "
-                f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
-                "or in the provider config."
-            )
+        if not api_key:
+            message = "API key not provided."
+            if self.provider_data_api_key_field:
+                message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
+            raise ValueError(message)

         return AsyncOpenAI(
             api_key=api_key,

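The net effect on adapters, sketched with hypothetical minimal classes (the real mixin also handles provider-data headers and client construction): `get_api_key()` is no longer abstract, so most adapters simply inherit the credential lookup, and only the no-auth providers override it.

```python
from pydantic import BaseModel, Field, SecretStr


class Config(BaseModel):
    auth_credential: SecretStr | None = Field(default=None, alias="api_key")


class Mixin(BaseModel):
    config: Config

    def get_api_key(self) -> str | None:
        # Inherited default: None means "not configured".
        if self.config.auth_credential is None:
            return None
        return self.config.auth_credential.get_secret_value()


class NoAuthAdapter(Mixin):
    def get_api_key(self) -> str | None:
        return "NO KEY REQUIRED"  # the override Ollama/TGI use in this PR
```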
@@ -76,6 +76,8 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
     fields_info = {}
     if hasattr(config_class, "model_fields"):
         for field_name, field in config_class.model_fields.items():
+            if getattr(field, "exclude", False):
+                continue
             field_type = str(field.annotation) if field.annotation else "Any"

             # this string replace is ridiculous

@@ -106,7 +108,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
                 "default": default_value,
                 "required": field.default is None and not field.is_required,
             }
-            fields_info[field_name] = field_info
+
+            # Use alias if available, otherwise use the field name
+            display_name = field.alias if field.alias else field_name
+            fields_info[display_name] = field_info

     if accepts_extra_config:
         config_description = "Additional configuration options that will be forwarded to the underlying provider"

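These two generator changes are what keep the provider docs tables above consistent: excluded credentials vanish from the tables, and aliased ones render under their public name. In miniature (assuming Pydantic v2, where `model_fields` values expose `alias` and `exclude`):

```python
from pydantic import BaseModel, Field, SecretStr


class Demo(BaseModel):
    auth_credential: SecretStr | None = Field(default=None, alias="api_key")
    internal: str | None = Field(default=None, exclude=True)


for name, field in Demo.model_fields.items():
    if getattr(field, "exclude", False):
        continue                  # skipped, like Ollama's credential
    print(field.alias or name)    # prints: api_key
```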
@@ -720,7 +720,7 @@ class TestOpenAIMixinProviderDataApiKey:
     ):
         """Test that ValueError is raised when provider data exists but doesn't have required key"""
         with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
-            with pytest.raises(ValueError, match="API key is not set"):
+            with pytest.raises(ValueError, match="API key not provided"):
                 _ = mixin_with_provider_data_field_and_none_api_key.client

     def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):

@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.core.stack import replace_env_vars
+from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
+from llama_stack.providers.remote.inference.azure import AzureConfig
+from llama_stack.providers.remote.inference.bedrock import BedrockConfig
+from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
+from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
+from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.gemini import GeminiConfig
+from llama_stack.providers.remote.inference.groq import GroqConfig
+from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
+from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
+from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.openai import OpenAIConfig
+from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
+from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
+from llama_stack.providers.remote.inference.tgi import TGIImplConfig
+from llama_stack.providers.remote.inference.together import TogetherImplConfig
+from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
+from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
+
+
+class TestRemoteInferenceProviderConfig:
+    @pytest.mark.parametrize(
+        "config_cls,alias_name,env_name,extra_config",
+        [
+            (AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
+            (AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
+            (BedrockConfig, None, None, {}),
+            (CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
+            (DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
+            (FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
+            (GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
+            (GroqConfig, "api_key", "GROQ_API_KEY", {}),
+            (LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
+            (NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
+            (OllamaImplConfig, None, None, {}),
+            (OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
+            (RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
+            (SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
+            (TGIImplConfig, None, None, {"url": "FAKE"}),
+            (TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
+            (VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
+            (VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
+            (WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
+        ],
+    )
+    def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
+        """Test that the config class correctly maps the alias to auth_credential."""
+        secret_value = config_cls.__name__
+
+        if alias_name is None:
+            pytest.skip("No alias name provided for this config class.")
+
+        config = config_cls(**{alias_name: secret_value, **extra_config})
+        assert config.auth_credential is not None
+        assert config.auth_credential.get_secret_value() == secret_value
+
+        schema = config_cls.model_json_schema()
+        assert alias_name in schema["properties"]
+        assert "auth_credential" not in schema["properties"]
+
+        if env_name:
+            monkeypatch.setenv(env_name, secret_value)
+            sample_config = config_cls.sample_run_config()
+            expanded_config = replace_env_vars(sample_config)
+            config_from_sample = config_cls(**{**expanded_config, **extra_config})
+            assert config_from_sample.auth_credential is not None
+            assert config_from_sample.auth_credential.get_secret_value() == secret_value