Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-11 13:44:38 +00:00
feat: use SecretStr for inference provider auth credentials (#3724)
# What does this PR do?

Use SecretStr for OpenAIMixin providers:

- RemoteInferenceProviderConfig now has auth_credential: SecretStr
- the default alias is api_key (the most common name)
- some providers override the alias to api_token (RunPod, vLLM, Databricks)
- some providers exclude the credential entirely (Ollama, TGI, Vertex AI)

Addresses #3517

## Test Plan

CI with new tests
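The pattern, condensed: a single `auth_credential: SecretStr` field on the shared base config, exposed under a provider-appropriate alias. A minimal sketch using plain pydantic v2; the `RunpodStyleConfig`/`OllamaStyleConfig` names are illustrative, not classes from this repo:

```python
from pydantic import BaseModel, Field, SecretStr


class RemoteInferenceProviderConfig(BaseModel):
    # Default spelling: most providers accept `api_key` in their configs.
    auth_credential: SecretStr | None = Field(default=None, alias="api_key")


class RunpodStyleConfig(RemoteInferenceProviderConfig):
    # RunPod, vLLM, and Databricks keep their historical `api_token` name.
    auth_credential: SecretStr | None = Field(default=None, alias="api_token")


class OllamaStyleConfig(RemoteInferenceProviderConfig):
    # Ollama, TGI, and Vertex AI take no credential, so the field is excluded
    # from schemas and serialized output entirely.
    auth_credential: SecretStr | None = Field(default=None, exclude=True)


cfg = RunpodStyleConfig(api_token="rp-secret")
print(cfg.auth_credential)                     # ********** (safe to log)
print(cfg.auth_credential.get_secret_value())  # rp-secret
```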
This commit is contained in:
parent 6d8f61206e
commit 0066d986c5

57 changed files with 158 additions and 149 deletions
@@ -16,7 +16,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for Anthropic models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |

 ## Sample Configuration
@@ -23,7 +23,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
 | `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) |
@@ -16,8 +16,8 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
-| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |

 ## Sample Configuration
@@ -16,8 +16,8 @@ Databricks inference provider for running models on Databricks' unified analytics platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The Databricks API token |
 | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
-| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |

 ## Sample Configuration
@@ -16,8 +16,8 @@ Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |

 ## Sample Configuration
@@ -16,7 +16,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for Gemini models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |

 ## Sample Configuration
@@ -16,7 +16,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | The Groq API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |

 ## Sample Configuration
@@ -16,7 +16,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | The Llama API key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

 ## Sample Configuration
@@ -16,8 +16,8 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
@@ -16,7 +16,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `str \| None` | No | | API key for OpenAI models |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

 ## Sample Configuration
@@ -16,8 +16,8 @@ Passthrough inference provider for connecting to any external inference service.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
+| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |

 ## Sample Configuration
@@ -16,8 +16,8 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
 | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
-| `api_token` | `str \| None` | No | | The API token |

 ## Sample Configuration
@@ -16,8 +16,8 @@ SambaNova inference provider for running models on SambaNova's dataflow architecture.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |

 ## Sample Configuration
@@ -16,8 +16,8 @@ Together AI inference provider for open-source models and collaborative AI development.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |

 ## Sample Configuration
@@ -16,9 +16,9 @@ Remote vLLM inference provider for connecting to vLLM servers.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
 | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
-| `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |

 ## Sample Configuration
@@ -16,8 +16,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
 | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
@@ -29,9 +29,6 @@ class AnthropicInferenceAdapter(OpenAIMixin):
     # "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
     # }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self):
         return "https://api.anthropic.com/v1"
@@ -21,11 +21,6 @@ class AnthropicProviderDataValidator(BaseModel):

 @json_schema_type
 class AnthropicConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for Anthropic models",
-    )
-
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
@@ -16,9 +16,6 @@ class AzureInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "azure_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
-
     def get_base_url(self) -> str:
         """
         Get the Azure API base URL.
@@ -32,9 +32,6 @@ class AzureProviderDataValidator(BaseModel):

 @json_schema_type
 class AzureConfig(RemoteInferenceProviderConfig):
-    api_key: SecretStr = Field(
-        description="Azure API key for Azure",
-    )
     api_base: HttpUrl = Field(
         description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
     )
@@ -15,9 +15,6 @@ from .config import CerebrasImplConfig

 class CerebrasInferenceAdapter(OpenAIMixin):
     config: CerebrasImplConfig

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
-
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")
@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -21,10 +21,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
-        description="Cerebras API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
@@ -18,8 +18,9 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
         default=None,
         description="The URL for the Databricks model serving endpoint",
     )
-    api_token: SecretStr = Field(
-        default=SecretStr(None),  # type: ignore[arg-type]
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        alias="api_token",
         description="The Databricks API token",
     )
@@ -27,9 +27,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
         "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_token.get_secret_value()
-
     def get_base_url(self) -> str:
         return f"{self.config.url}/serving-endpoints"
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
         default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The Fireworks.ai API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
@@ -23,8 +23,5 @@ class FireworksInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "fireworks_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value() if self.config.api_key else None  # type: ignore[return-value]
-
     def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
@@ -21,11 +21,6 @@ class GeminiProviderDataValidator(BaseModel):

 @json_schema_type
 class GeminiConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for Gemini models",
-    )
-
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
@@ -17,8 +17,5 @@ class GeminiInferenceAdapter(OpenAIMixin):
         "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self):
         return "https://generativelanguage.googleapis.com/v1beta/openai/"
@@ -21,12 +21,6 @@ class GroqProviderDataValidator(BaseModel):

 @json_schema_type
 class GroqConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        # The Groq client library loads the GROQ_API_KEY environment variable by default
-        default=None,
-        description="The Groq API key",
-    )
-
     url: str = Field(
         default="https://api.groq.com",
         description="The URL for the Groq AI server",
@@ -14,8 +14,5 @@ class GroqInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "groq_api_key"

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         return f"{self.config.url}/openai/v1"
@@ -21,11 +21,6 @@ class LlamaProviderDataValidator(BaseModel):

 @json_schema_type
 class LlamaCompatConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="The Llama API key",
-    )
-
     openai_compat_api_base: str = Field(
         default="https://api.llama.com/compat/v1/",
         description="The URL for the Llama API server",
@@ -21,9 +21,6 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
    Llama API Inference Adapter for Llama Stack.
    """

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         """
         Get the base URL for OpenAI mixin.
@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -40,10 +40,6 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
         description="A base url for accessing the NVIDIA NIM",
     )
-    api_key: SecretStr | None = Field(
-        default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
-        description="The NVIDIA API key, only needed of using the hosted service",
-    )
     timeout: int = Field(
         default=60,
         description="Timeout for the HTTP requests",
@@ -49,7 +49,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")

         if _is_nvidia_hosted(self.config):
-            if not self.config.api_key:
+            if not self.config.auth_credential:
                 raise RuntimeError(
                     "API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
                 )
@@ -60,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin):

         :return: The NVIDIA API key
         """
-        return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
+        if self.config.auth_credential:
+            return self.config.auth_credential.get_secret_value()
+
+        if not _is_nvidia_hosted(self.config):
+            return "NO KEY REQUIRED"
+
+        return None

     def get_base_url(self) -> str:
         """
@@ -6,12 +6,16 @@

 from typing import Any

+from pydantic import Field, SecretStr
+
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
+
     url: str = DEFAULT_OLLAMA_URL

     @classmethod
@@ -59,7 +59,7 @@ class OllamaInferenceAdapter(OpenAIMixin):
         return self._clients[loop]

     def get_api_key(self):
-        return "NO_KEY"
+        return "NO KEY REQUIRED"

     def get_base_url(self):
         return self.config.url.rstrip("/") + "/v1"
@@ -21,10 +21,6 @@ class OpenAIProviderDataValidator(BaseModel):

 @json_schema_type
 class OpenAIConfig(RemoteInferenceProviderConfig):
-    api_key: str | None = Field(
-        default=None,
-        description="API key for OpenAI models",
-    )
     base_url: str = Field(
         default="https://api.openai.com/v1",
         description="Base URL for OpenAI API",
@@ -29,9 +29,6 @@ class OpenAIInferenceAdapter(OpenAIMixin):
         "text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
     }

-    def get_api_key(self) -> str:
-        return self.config.api_key or ""
-
     def get_base_url(self) -> str:
         """
         Get the OpenAI API base URL.
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field
+from pydantic import Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -18,8 +18,9 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
         default=None,
         description="The URL for the Runpod model serving endpoint",
     )
-    api_token: str | None = Field(
+    auth_credential: SecretStr | None = Field(
         default=None,
+        alias="api_token",
         description="The API token",
     )
@@ -24,10 +24,6 @@ class RunpodInferenceAdapter(OpenAIMixin):

     config: RunpodImplConfig

-    def get_api_key(self) -> str:
-        """Get API key for OpenAI client."""
-        return self.config.api_token
-
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return self.config.url
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -25,10 +25,6 @@ class SambaNovaImplConfig(RemoteInferenceProviderConfig):
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The SambaNova cloud API Key",
-    )

     @classmethod
     def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
@@ -19,9 +19,6 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
    SambaNova Inference Adapter for Llama Stack.
    """

-    def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value() if self.config.api_key else ""
-
     def get_base_url(self) -> str:
         """
         Get the base URL for OpenAI mixin.
@@ -13,6 +13,8 @@ from llama_stack.schema_utils import json_schema_type

 @json_schema_type
 class TGIImplConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
+
     url: str = Field(
         description="The URL for the TGI serving endpoint",
     )
@@ -30,7 +30,7 @@ class _HfAdapter(OpenAIMixin):
     overwrite_completion_id = True  # TGI always returns id=""

     def get_api_key(self):
-        return self.api_key.get_secret_value()
+        return "NO KEY REQUIRED"

     def get_base_url(self):
         return self.url
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The Together AI API Key",
-    )

     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
@@ -39,15 +39,12 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):

     provider_data_api_key_field: str = "together_api_key"

-    def get_api_key(self):
-        return self.config.api_key.get_secret_value() if self.config.api_key else None
-
     def get_base_url(self):
         return BASE_URL

     def _get_client(self) -> AsyncTogether:
         together_api_key = None
-        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
         if config_api_key:
             together_api_key = config_api_key
         else:
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -25,6 +25,8 @@ class VertexAIProviderDataValidator(BaseModel):

 @json_schema_type
 class VertexAIConfig(RemoteInferenceProviderConfig):
+    auth_credential: SecretStr | None = Field(default=None, exclude=True)
+
     project: str = Field(
         description="Google Cloud project ID for Vertex AI",
     )
@@ -6,7 +6,7 @@

 from pathlib import Path

-from pydantic import Field, field_validator
+from pydantic import Field, SecretStr, field_validator

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -22,8 +22,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=4096,
         description="Maximum number of tokens to generate.",
     )
-    api_token: str | None = Field(
-        default="fake",
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        alias="api_token",
         description="The API token",
     )
     tls_verify: bool | str = Field(
@@ -38,8 +38,10 @@ class VLLMInferenceAdapter(OpenAIMixin):

     provider_data_api_key_field: str = "vllm_api_token"

-    def get_api_key(self) -> str:
-        return self.config.api_token or ""
+    def get_api_key(self) -> str | None:
+        if self.config.auth_credential:
+            return self.config.auth_credential.get_secret_value()
+        return "NO KEY REQUIRED"

     def get_base_url(self) -> str:
         """Get the base URL from config."""
@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import BaseModel, ConfigDict, Field, SecretStr
+from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -27,14 +27,6 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
-    # This seems like it should be required, but none of the other remote inference
-    # providers require it, so this is optional here too for consistency.
-    # The OpenAIConfig uses default=None instead, so this is following that precedent.
-    api_key: SecretStr | None = Field(
-        default=None,
-        description="The watsonx.ai API key",
-    )
-    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
         default=None,
         description="The watsonx.ai project ID",
@@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
             provider_data_api_key_field="watsonx_api_key",
         )
         self.available_models = None
@@ -12,6 +12,7 @@ from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig


 class BedrockBaseConfig(RemoteInferenceProviderConfig):
+    auth_credential: None = Field(default=None, exclude=True)
     aws_access_key_id: str | None = Field(
         default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
         description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.models import ModelType
@@ -28,6 +28,11 @@ class RemoteInferenceProviderConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically from the provider",
     )
+    auth_credential: SecretStr | None = Field(
+        default=None,
+        description="Authentication credential for the provider",
+        alias="api_key",
+    )


 # TODO: this class is more confusing than useful right now. We need to make it
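For context on why `SecretStr` (rather than `str`) matters here, a small pydantic-only sketch; `Cfg` is illustrative. String conversion and serialization stay masked, the JSON schema advertises the `api_key` alias instead of the internal field name, and the raw value is only available through an explicit `get_secret_value()` call:

```python
from pydantic import BaseModel, Field, SecretStr


class Cfg(BaseModel):
    auth_credential: SecretStr | None = Field(default=None, alias="api_key")


cfg = Cfg(api_key="sk-live-123")
print(cfg.model_dump_json())                          # {"auth_credential":"**********"}
print(sorted(Cfg.model_json_schema()["properties"]))  # ['api_key']
print(cfg.auth_credential.get_secret_value())         # sk-live-123
```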
@@ -40,7 +40,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     This class handles direct OpenAI API calls using the AsyncOpenAI client.

     This is an abstract base class that requires child classes to implement:
-    - get_api_key(): Method to retrieve the API key
     - get_base_url(): Method to retrieve the OpenAI-compatible API base URL

     The behavior of this class can be customized by child classes in the following ways:
@@ -87,17 +86,15 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # Optional field name in provider data to look for API key, which takes precedence
     provider_data_api_key_field: str | None = None

-    @abstractmethod
-    def get_api_key(self) -> str:
+    def get_api_key(self) -> str | None:
         """
         Get the API key.

-        This method must be implemented by child classes to provide the API key
-        for authenticating with the OpenAI API or compatible endpoints.
-
-        :return: The API key as a string
+        :return: The API key as a string, or None if not set
         """
-        pass
+        if self.config.auth_credential is None:
+            return None
+        return self.config.auth_credential.get_secret_value()

     @abstractmethod
     def get_base_url(self) -> str:
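With `get_api_key()` now concrete on the mixin, adapters for unauthenticated backends simply override it to return the `"NO KEY REQUIRED"` sentinel, as the Ollama, TGI, and vLLM diffs above do. A minimal stand-in for that contract in plain pydantic; `MixinLike` and `NoAuthAdapter` are illustrative names, not classes from this repo:

```python
from pydantic import BaseModel, Field, SecretStr


class ProviderConfig(BaseModel):
    auth_credential: SecretStr | None = Field(default=None, alias="api_key")


class MixinLike(BaseModel):
    config: ProviderConfig

    def get_api_key(self) -> str | None:
        # Mirrors the new OpenAIMixin.get_api_key: unwrap the credential or return None.
        if self.config.auth_credential is None:
            return None
        return self.config.auth_credential.get_secret_value()


class NoAuthAdapter(MixinLike):
    def get_api_key(self) -> str:
        # Ollama/TGI-style override for backends that need no credential.
        return "NO KEY REQUIRED"


print(MixinLike(config=ProviderConfig(api_key="k")).get_api_key())  # k
print(NoAuthAdapter(config=ProviderConfig()).get_api_key())         # NO KEY REQUIRED
```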
@@ -176,13 +173,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
             api_key = getattr(provider_data, self.provider_data_api_key_field)

-        if not api_key:  # TODO: let get_api_key return None
-            raise ValueError(
-                "API key is not set. Please provide a valid API key in the "
-                "provider data header, e.g. x-llamastack-provider-data: "
-                f'{{"{self.provider_data_api_key_field}": "<API_KEY>"}}, '
-                "or in the provider config."
-            )
+        if not api_key:
+            message = "API key not provided."
+            if self.provider_data_api_key_field:
+                message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
+            raise ValueError(message)

         return AsyncOpenAI(
             api_key=api_key,
@@ -76,6 +76,8 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
     fields_info = {}
     if hasattr(config_class, "model_fields"):
         for field_name, field in config_class.model_fields.items():
+            if getattr(field, "exclude", False):
+                continue
             field_type = str(field.annotation) if field.annotation else "Any"

             # this string replace is ridiculous
@@ -106,7 +108,10 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
                 "default": default_value,
                 "required": field.default is None and not field.is_required,
             }
-            fields_info[field_name] = field_info
+
+            # Use alias if available, otherwise use the field name
+            display_name = field.alias if field.alias else field_name
+            fields_info[display_name] = field_info

     if accepts_extra_config:
         config_description = "Additional configuration options that will be forwarded to the underlying provider"
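The generator change leans on pydantic v2 `FieldInfo` metadata: excluded fields are skipped and aliased fields are documented under their public name, which is how the provider tables above end up showing `api_key`/`api_token` rather than `auth_credential`. A short sketch of that introspection, with an illustrative `Cfg` class:

```python
from pydantic import BaseModel, Field, SecretStr


class Cfg(BaseModel):
    auth_credential: SecretStr | None = Field(default=None, alias="api_key")
    internal: SecretStr | None = Field(default=None, exclude=True)


for name, field in Cfg.model_fields.items():
    if getattr(field, "exclude", False):
        continue  # excluded credentials never reach the docs table
    display_name = field.alias if field.alias else name
    print(display_name, str(field.annotation))
# prints: api_key pydantic.types.SecretStr | None
```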
@@ -720,7 +720,7 @@ class TestOpenAIMixinProviderDataApiKey:
     ):
         """Test that ValueError is raised when provider data exists but doesn't have required key"""
         with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
-            with pytest.raises(ValueError, match="API key is not set"):
+            with pytest.raises(ValueError, match="API key not provided"):
                 _ = mixin_with_provider_data_field_and_none_api_key.client

     def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.core.stack import replace_env_vars
+from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
+from llama_stack.providers.remote.inference.azure import AzureConfig
+from llama_stack.providers.remote.inference.bedrock import BedrockConfig
+from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
+from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
+from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.gemini import GeminiConfig
+from llama_stack.providers.remote.inference.groq import GroqConfig
+from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
+from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
+from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.openai import OpenAIConfig
+from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
+from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
+from llama_stack.providers.remote.inference.tgi import TGIImplConfig
+from llama_stack.providers.remote.inference.together import TogetherImplConfig
+from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
+from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
+
+
+class TestRemoteInferenceProviderConfig:
+    @pytest.mark.parametrize(
+        "config_cls,alias_name,env_name,extra_config",
+        [
+            (AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
+            (AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
+            (BedrockConfig, None, None, {}),
+            (CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
+            (DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
+            (FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
+            (GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
+            (GroqConfig, "api_key", "GROQ_API_KEY", {}),
+            (LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
+            (NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
+            (OllamaImplConfig, None, None, {}),
+            (OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
+            (RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
+            (SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
+            (TGIImplConfig, None, None, {"url": "FAKE"}),
+            (TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
+            (VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
+            (VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
+            (WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
+        ],
+    )
+    def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
+        """Test that the config class correctly maps the alias to auth_credential."""
+        secret_value = config_cls.__name__
+
+        if alias_name is None:
+            pytest.skip("No alias name provided for this config class.")
+
+        config = config_cls(**{alias_name: secret_value, **extra_config})
+        assert config.auth_credential is not None
+        assert config.auth_credential.get_secret_value() == secret_value
+
+        schema = config_cls.model_json_schema()
+        assert alias_name in schema["properties"]
+        assert "auth_credential" not in schema["properties"]
+
+        if env_name:
+            monkeypatch.setenv(env_name, secret_value)
+            sample_config = config_cls.sample_run_config()
+            expanded_config = replace_env_vars(sample_config)
+            config_from_sample = config_cls(**{**expanded_config, **extra_config})
+            assert config_from_sample.auth_credential is not None
+            assert config_from_sample.auth_credential.get_secret_value() == secret_value