mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
fix: prevent telemetry from leaking sensitive info
Prevent sensitive information from being logged in telemetry output by assigning SecretStr type to sensitive fields. API keys and passwords from the KV store are now covered. All providers have been converted. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
8dc9fd6844
commit
c4cb6aa8d9
53 changed files with 121 additions and 109 deletions
|
@ -17,7 +17,7 @@ AWS S3-based file storage provider for scalable cloud file management with metad
|
||||||
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
|
||||||
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
|
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
|
||||||
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
|
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
|
||||||
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
|
| `aws_secret_access_key` | `pydantic.types.SecretStr \| None` | No | | AWS secret access key (optional if using IAM roles) |
|
||||||
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
|
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
|
||||||
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
|
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
|
||||||
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
|
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
|
||||||
|
|
|
@ -14,7 +14,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | API key for Anthropic models |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | API key for Anthropic models |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
|
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
|
||||||
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
|
| `aws_secret_access_key` | `pydantic.types.SecretStr \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
|
||||||
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
|
| `aws_session_token` | `pydantic.types.SecretStr \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
|
||||||
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
|
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
|
||||||
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
||||||
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
||||||
|
|
|
@ -14,7 +14,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | API key for Gemini models |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | API key for Gemini models |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | The Groq API key |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Groq API key |
|
||||||
| `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |
|
| `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -14,7 +14,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | The Llama API key |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Llama API key |
|
||||||
| `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
|
| `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -14,7 +14,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | API key for OpenAI models |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | API key for OpenAI models |
|
||||||
| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
|
| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -15,7 +15,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
|
| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
|
||||||
| `api_token` | `str \| None` | No | | The API token |
|
| `api_token` | `pydantic.types.SecretStr \| None` | No | | The API token |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
|
| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
|
||||||
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
|
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
|
||||||
| `api_token` | `str \| None` | No | fake | The API token |
|
| `api_token` | `pydantic.types.SecretStr \| None` | No | ********** | The API token |
|
||||||
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
|
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
|
||||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
|
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
|
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
|
||||||
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
|
| `aws_secret_access_key` | `pydantic.types.SecretStr \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
|
||||||
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
|
| `aws_session_token` | `pydantic.types.SecretStr \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
|
||||||
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
|
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
|
||||||
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
||||||
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
||||||
|
|
|
@ -14,7 +14,7 @@ Braintrust scoring provider for evaluation and scoring using the Braintrust plat
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `openai_api_key` | `str \| None` | No | | The OpenAI API Key |
|
| `openai_api_key` | `pydantic.types.SecretStr \| None` | No | | The OpenAI API Key |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ Bing Search tool for web search capabilities using Microsoft's search engine.
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | |
|
||||||
| `top_k` | `<class 'int'>` | No | 3 | |
|
| `top_k` | `<class 'int'>` | No | 3 | |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -14,7 +14,7 @@ Brave Search tool for web search capabilities with privacy-focused results.
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | The Brave Search API Key |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Brave Search API Key |
|
||||||
| `max_results` | `<class 'int'>` | No | 3 | The maximum number of results to return |
|
| `max_results` | `<class 'int'>` | No | 3 | The maximum number of results to return |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -14,7 +14,7 @@ Tavily Search tool for AI-optimized web search with structured results.
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Required | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `api_key` | `str \| None` | No | | The Tavily Search API Key |
|
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Tavily Search API Key |
|
||||||
| `max_results` | `<class 'int'>` | No | 3 | The maximum number of results to return |
|
| `max_results` | `<class 'int'>` | No | 3 | The maximum number of results to return |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -217,7 +217,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
||||||
| `port` | `int \| None` | No | 5432 | |
|
| `port` | `int \| None` | No | 5432 | |
|
||||||
| `db` | `str \| None` | No | postgres | |
|
| `db` | `str \| None` | No | postgres | |
|
||||||
| `user` | `str \| None` | No | postgres | |
|
| `user` | `str \| None` | No | postgres | |
|
||||||
| `password` | `str \| None` | No | mysecretpassword | |
|
| `password` | `pydantic.types.SecretStr \| None` | No | mysecretpassword | |
|
||||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from llama_stack.core.datatypes import Api
|
from llama_stack.core.datatypes import Api
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from .config import BraintrustScoringConfig
|
||||||
|
|
||||||
|
|
||||||
class BraintrustProviderDataValidator(BaseModel):
|
class BraintrustProviderDataValidator(BaseModel):
|
||||||
openai_api_key: str
|
openai_api_key: SecretStr
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
|
|
|
@ -17,7 +17,7 @@ from autoevals.ragas import (
|
||||||
ContextRelevancy,
|
ContextRelevancy,
|
||||||
Faithfulness,
|
Faithfulness,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
@ -152,9 +152,9 @@ class BraintrustScoringImpl(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
|
'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
|
||||||
)
|
)
|
||||||
self.config.openai_api_key = provider_data.openai_api_key
|
self.config.openai_api_key = SecretStr(provider_data.openai_api_key)
|
||||||
|
|
||||||
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
|
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key.get_secret_value()
|
||||||
|
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -5,11 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
|
||||||
class BraintrustScoringConfig(BaseModel):
|
class BraintrustScoringConfig(BaseModel):
|
||||||
openai_api_key: str | None = Field(
|
openai_api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The OpenAI API Key",
|
description="The OpenAI API Key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -64,7 +64,9 @@ class ConsoleSpanProcessor(SpanProcessor):
|
||||||
for key, value in event.attributes.items():
|
for key, value in event.attributes.items():
|
||||||
if key.startswith("__") or key in ["message", "severity"]:
|
if key.startswith("__") or key in ["message", "severity"]:
|
||||||
continue
|
continue
|
||||||
logger.info(f"[dim]{key}[/dim]: {value}")
|
|
||||||
|
str_value = str(value)
|
||||||
|
logger.info(f"[dim]{key}[/dim]: {str_value}")
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Shutdown the processor."""
|
"""Shutdown the processor."""
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ class S3FilesImplConfig(BaseModel):
|
||||||
bucket_name: str = Field(description="S3 bucket name to store files")
|
bucket_name: str = Field(description="S3 bucket name to store files")
|
||||||
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
|
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
|
||||||
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
|
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
|
||||||
aws_secret_access_key: str | None = Field(
|
aws_secret_access_key: SecretStr | None = Field(
|
||||||
default=None, description="AWS secret access key (optional if using IAM roles)"
|
default=None, description="AWS secret access key (optional if using IAM roles)"
|
||||||
)
|
)
|
||||||
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
|
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
|
||||||
|
|
|
@ -46,7 +46,7 @@ def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
|
||||||
s3_config.update(
|
s3_config.update(
|
||||||
{
|
{
|
||||||
"aws_access_key_id": config.aws_access_key_id,
|
"aws_access_key_id": config.aws_access_key_id,
|
||||||
"aws_secret_access_key": config.aws_secret_access_key,
|
"aws_secret_access_key": config.aws_secret_access_key.get_secret_value(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
@ -27,7 +28,7 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
litellm_provider_name="anthropic",
|
litellm_provider_name="anthropic",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||||
provider_data_api_key_field="anthropic_api_key",
|
provider_data_api_key_field="anthropic_api_key",
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class AnthropicProviderDataValidator(BaseModel):
|
class AnthropicProviderDataValidator(BaseModel):
|
||||||
anthropic_api_key: str | None = Field(
|
anthropic_api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for Anthropic models",
|
description="API key for Anthropic models",
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AnthropicConfig(BaseModel):
|
class AnthropicConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for Anthropic models",
|
description="API key for Anthropic models",
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class GeminiProviderDataValidator(BaseModel):
|
class GeminiProviderDataValidator(BaseModel):
|
||||||
gemini_api_key: str | None = Field(
|
gemini_api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for Gemini models",
|
description="API key for Gemini models",
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class GeminiConfig(BaseModel):
|
class GeminiConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for Gemini models",
|
description="API key for Gemini models",
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
@ -19,7 +20,7 @@ class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
litellm_provider_name="gemini",
|
litellm_provider_name="gemini",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||||
provider_data_api_key_field="gemini_api_key",
|
provider_data_api_key_field="gemini_api_key",
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class GroqProviderDataValidator(BaseModel):
|
class GroqProviderDataValidator(BaseModel):
|
||||||
groq_api_key: str | None = Field(
|
groq_api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for Groq models",
|
description="API key for Groq models",
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ class GroqProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class GroqConfig(BaseModel):
|
class GroqConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
# The Groq client library loads the GROQ_API_KEY environment variable by default
|
||||||
default=None,
|
default=None,
|
||||||
description="The Groq API key",
|
description="The Groq API key",
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class LlamaProviderDataValidator(BaseModel):
|
class LlamaProviderDataValidator(BaseModel):
|
||||||
llama_api_key: str | None = Field(
|
llama_api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for api.llama models",
|
description="API key for api.llama models",
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class LlamaCompatConfig(BaseModel):
|
class LlamaCompatConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The Llama API key",
|
description="The Llama API key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProviderDataValidator(BaseModel):
|
class OpenAIProviderDataValidator(BaseModel):
|
||||||
openai_api_key: str | None = Field(
|
openai_api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for OpenAI models",
|
description="API key for OpenAI models",
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIConfig(BaseModel):
|
class OpenAIConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="API key for OpenAI models",
|
description="API key for OpenAI models",
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ class RunpodImplConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="The URL for the Runpod model serving endpoint",
|
description="The URL for the Runpod model serving endpoint",
|
||||||
)
|
)
|
||||||
api_token: str | None = Field(
|
api_token: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The API token",
|
description="The API token",
|
||||||
)
|
)
|
||||||
|
|
|
@ -103,7 +103,10 @@ class RunpodInferenceAdapter(
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
client = OpenAI(
|
||||||
|
base_url=self.config.url,
|
||||||
|
api_key=self.config.api_token.get_secret_value() if self.config.api_token else None,
|
||||||
|
)
|
||||||
if stream:
|
if stream:
|
||||||
return self._stream_chat_completion(request, client)
|
return self._stream_chat_completion(request, client)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
import google.auth.transport.requests
|
import google.auth.transport.requests
|
||||||
from google.auth import default
|
from google.auth import default
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from llama_stack.apis.inference import ChatCompletionRequest
|
from llama_stack.apis.inference import ChatCompletionRequest
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||||
|
@ -43,7 +44,7 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
except Exception:
|
except Exception:
|
||||||
# If we can't get credentials, return empty string to let LiteLLM handle it
|
# If we can't get credentials, return empty string to let LiteLLM handle it
|
||||||
# This allows the LiteLLM mixin to work with ADC directly
|
# This allows the LiteLLM mixin to work with ADC directly
|
||||||
return ""
|
return SecretStr("")
|
||||||
|
|
||||||
def get_base_url(self) -> str:
|
def get_base_url(self) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, SecretStr, field_validator
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -21,8 +21,8 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
||||||
default=4096,
|
default=4096,
|
||||||
description="Maximum number of tokens to generate.",
|
description="Maximum number of tokens to generate.",
|
||||||
)
|
)
|
||||||
api_token: str | None = Field(
|
api_token: SecretStr | None = Field(
|
||||||
default="fake",
|
default=SecretStr("fake"),
|
||||||
description="The API token",
|
description="The API token",
|
||||||
)
|
)
|
||||||
tls_verify: bool | str = Field(
|
tls_verify: bool | str = Field(
|
||||||
|
|
|
@ -294,7 +294,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
|
||||||
self,
|
self,
|
||||||
model_entries=build_hf_repo_model_entries(),
|
model_entries=build_hf_repo_model_entries(),
|
||||||
litellm_provider_name="vllm",
|
litellm_provider_name="vllm",
|
||||||
api_key_from_config=config.api_token,
|
api_key_from_config=config.api_token.get_secret_value(),
|
||||||
provider_data_api_key_field="vllm_api_token",
|
provider_data_api_key_field="vllm_api_token",
|
||||||
openai_compat_api_base=config.url,
|
openai_compat_api_base=config.url,
|
||||||
)
|
)
|
||||||
|
|
|
@ -40,7 +40,7 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
if self.config.api_key:
|
if self.config.api_key:
|
||||||
return self.config.api_key
|
return self.config.api_key.get_secret_value()
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.bing_search_api_key:
|
if provider_data is None or not provider_data.bing_search_api_key:
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
|
|
||||||
class BingSearchToolConfig(BaseModel):
|
class BingSearchToolConfig(BaseModel):
|
||||||
"""Configuration for Bing Search Tool Runtime"""
|
"""Configuration for Bing Search Tool Runtime"""
|
||||||
|
|
||||||
api_key: str | None = None
|
api_key: SecretStr | None = None
|
||||||
top_k: int = 3
|
top_k: int = 3
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -39,7 +39,7 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
if self.config.api_key:
|
if self.config.api_key:
|
||||||
return self.config.api_key
|
return self.config.api_key.get_secret_value()
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.brave_search_api_key:
|
if provider_data is None or not provider_data.brave_search_api_key:
|
||||||
|
|
|
@ -6,11 +6,11 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolConfig(BaseModel):
|
class BraveSearchToolConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The Brave Search API Key",
|
description="The Brave Search API Key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,11 +6,11 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchToolConfig(BaseModel):
|
class TavilySearchToolConfig(BaseModel):
|
||||||
api_key: str | None = Field(
|
api_key: SecretStr | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The Tavily Search API Key",
|
description="The Tavily Search API Key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -39,7 +39,7 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
if self.config.api_key:
|
if self.config.api_key:
|
||||||
return self.config.api_key
|
return self.config.api_key.get_secret_value()
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.tavily_search_api_key:
|
if provider_data is None or not provider_data.tavily_search_api_key:
|
||||||
|
|
|
@ -40,7 +40,7 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
if self.config.api_key:
|
if self.config.api_key:
|
||||||
return self.config.api_key
|
return self.config.api_key.get_secret_value()
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.wolfram_alpha_api_key:
|
if provider_data is None or not provider_data.wolfram_alpha_api_key:
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore.config import (
|
from llama_stack.providers.utils.kvstore.config import (
|
||||||
KVStoreConfig,
|
KVStoreConfig,
|
||||||
|
@ -21,7 +21,7 @@ class PGVectorVectorIOConfig(BaseModel):
|
||||||
port: int | None = Field(default=5432)
|
port: int | None = Field(default=5432)
|
||||||
db: str | None = Field(default="postgres")
|
db: str | None = Field(default="postgres")
|
||||||
user: str | None = Field(default="postgres")
|
user: str | None = Field(default="postgres")
|
||||||
password: str | None = Field(default="mysecretpassword")
|
password: SecretStr | None = Field(default="mysecretpassword")
|
||||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -366,7 +366,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
||||||
port=self.config.port,
|
port=self.config.port,
|
||||||
database=self.config.db,
|
database=self.config.db,
|
||||||
user=self.config.user,
|
user=self.config.user,
|
||||||
password=self.config.password,
|
password=self.config.password.get_secret_value(),
|
||||||
)
|
)
|
||||||
self.conn.autocommit = True
|
self.conn.autocommit = True
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
|
|
|
@ -50,8 +50,8 @@ def create_bedrock_client(config: BedrockBaseConfig, service_name: str = "bedroc
|
||||||
|
|
||||||
session_args = {
|
session_args = {
|
||||||
"aws_access_key_id": config.aws_access_key_id,
|
"aws_access_key_id": config.aws_access_key_id,
|
||||||
"aws_secret_access_key": config.aws_secret_access_key,
|
"aws_secret_access_key": config.aws_secret_access_key.get_secret_value(),
|
||||||
"aws_session_token": config.aws_session_token,
|
"aws_session_token": config.aws_session_token.get_secret_value(),
|
||||||
"region_name": config.region_name,
|
"region_name": config.region_name,
|
||||||
"profile_name": config.profile_name,
|
"profile_name": config.profile_name,
|
||||||
"session_ttl": config.session_ttl,
|
"session_ttl": config.session_ttl,
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
|
||||||
class BedrockBaseConfig(BaseModel):
|
class BedrockBaseConfig(BaseModel):
|
||||||
|
@ -14,12 +14,12 @@ class BedrockBaseConfig(BaseModel):
|
||||||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||||
)
|
)
|
||||||
aws_secret_access_key: str | None = Field(
|
aws_secret_access_key: SecretStr | None = Field(
|
||||||
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
|
default_factory=lambda: SecretStr(val) if (val := os.getenv("AWS_SECRET_ACCESS_KEY")) else None,
|
||||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||||
)
|
)
|
||||||
aws_session_token: str | None = Field(
|
aws_session_token: SecretStr | None = Field(
|
||||||
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
|
default_factory=lambda: SecretStr(val) if (val := os.getenv("AWS_SESSION_TOKEN")) else None,
|
||||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||||
)
|
)
|
||||||
region_name: str | None = Field(
|
region_name: str | None = Field(
|
||||||
|
|
|
@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
@ -68,7 +69,7 @@ class LiteLLMOpenAIMixin(
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
litellm_provider_name: str,
|
litellm_provider_name: str,
|
||||||
api_key_from_config: str | None,
|
api_key_from_config: SecretStr | None,
|
||||||
provider_data_api_key_field: str,
|
provider_data_api_key_field: str,
|
||||||
model_entries: list[ProviderModelEntry] | None = None,
|
model_entries: list[ProviderModelEntry] | None = None,
|
||||||
openai_compat_api_base: str | None = None,
|
openai_compat_api_base: str | None = None,
|
||||||
|
@ -247,14 +248,14 @@ class LiteLLMOpenAIMixin(
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": request.model,
|
"model": request.model,
|
||||||
"api_key": self.get_api_key(),
|
"api_key": self.get_api_key().get_secret_value(),
|
||||||
"api_base": self.api_base,
|
"api_base": self.api_base,
|
||||||
**input_dict,
|
**input_dict,
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
**get_sampling_options(request.sampling_params),
|
**get_sampling_options(request.sampling_params),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> SecretStr:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
key_field = self.provider_data_api_key_field
|
key_field = self.provider_data_api_key_field
|
||||||
if provider_data and getattr(provider_data, key_field, None):
|
if provider_data and getattr(provider_data, key_field, None):
|
||||||
|
@ -305,7 +306,7 @@ class LiteLLMOpenAIMixin(
|
||||||
response = litellm.embedding(
|
response = litellm.embedding(
|
||||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||||
input=input_list,
|
input=input_list,
|
||||||
api_key=self.get_api_key(),
|
api_key=self.get_api_key().get_secret_value(),
|
||||||
api_base=self.api_base,
|
api_base=self.api_base,
|
||||||
dimensions=dimensions,
|
dimensions=dimensions,
|
||||||
)
|
)
|
||||||
|
@ -368,7 +369,7 @@ class LiteLLMOpenAIMixin(
|
||||||
user=user,
|
user=user,
|
||||||
guided_choice=guided_choice,
|
guided_choice=guided_choice,
|
||||||
prompt_logprobs=prompt_logprobs,
|
prompt_logprobs=prompt_logprobs,
|
||||||
api_key=self.get_api_key(),
|
api_key=self.get_api_key().get_secret_value(),
|
||||||
api_base=self.api_base,
|
api_base=self.api_base,
|
||||||
)
|
)
|
||||||
return await litellm.atext_completion(**params)
|
return await litellm.atext_completion(**params)
|
||||||
|
@ -424,7 +425,7 @@ class LiteLLMOpenAIMixin(
|
||||||
top_logprobs=top_logprobs,
|
top_logprobs=top_logprobs,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
user=user,
|
user=user,
|
||||||
api_key=self.get_api_key(),
|
api_key=self.get_api_key().get_secret_value(),
|
||||||
api_base=self.api_base,
|
api_base=self.api_base,
|
||||||
)
|
)
|
||||||
return await litellm.acompletion(**params)
|
return await litellm.acompletion(**params)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from openai import NOT_GIVEN, AsyncOpenAI
|
from openai import NOT_GIVEN, AsyncOpenAI
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
Model,
|
Model,
|
||||||
|
@ -70,14 +71,14 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||||
allowed_models: list[str] = []
|
allowed_models: list[str] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> SecretStr:
|
||||||
"""
|
"""
|
||||||
Get the API key.
|
Get the API key.
|
||||||
|
|
||||||
This method must be implemented by child classes to provide the API key
|
This method must be implemented by child classes to provide the API key
|
||||||
for authenticating with the OpenAI API or compatible endpoints.
|
for authenticating with the OpenAI API or compatible endpoints.
|
||||||
|
|
||||||
:return: The API key as a string
|
:return: The API key as a SecretStr
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, SecretStr, field_validator
|
||||||
|
|
||||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class PostgresKVStoreConfig(CommonConfig):
|
||||||
port: int = 5432
|
port: int = 5432
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
user: str
|
user: str
|
||||||
password: str | None = None
|
password: SecretStr | None = None
|
||||||
ssl_mode: str | None = None
|
ssl_mode: str | None = None
|
||||||
ca_cert_path: str | None = None
|
ca_cert_path: str | None = None
|
||||||
table_name: str = "llamastack_kvstore"
|
table_name: str = "llamastack_kvstore"
|
||||||
|
@ -118,7 +118,7 @@ class MongoDBKVStoreConfig(CommonConfig):
|
||||||
port: int = 27017
|
port: int = 27017
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
password: str | None = None
|
password: SecretStr | None = None
|
||||||
collection_name: str = "llamastack_kvstore"
|
collection_name: str = "llamastack_kvstore"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -34,7 +34,7 @@ class MongoDBKVStoreImpl(KVStore):
|
||||||
"host": self.config.host,
|
"host": self.config.host,
|
||||||
"port": self.config.port,
|
"port": self.config.port,
|
||||||
"username": self.config.user,
|
"username": self.config.user,
|
||||||
"password": self.config.password,
|
"password": self.config.password.get_secret_value(),
|
||||||
}
|
}
|
||||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
||||||
self.conn = AsyncMongoClient(**conn_creds)
|
self.conn = AsyncMongoClient(**conn_creds)
|
||||||
|
|
|
@ -30,7 +30,7 @@ class PostgresKVStoreImpl(KVStore):
|
||||||
port=self.config.port,
|
port=self.config.port,
|
||||||
database=self.config.db,
|
database=self.config.db,
|
||||||
user=self.config.user,
|
user=self.config.user,
|
||||||
password=self.config.password,
|
password=self.config.password.get_secret_value(),
|
||||||
sslmode=self.config.ssl_mode,
|
sslmode=self.config.ssl_mode,
|
||||||
sslrootcert=self.config.ca_cert_path,
|
sslrootcert=self.config.ca_cert_path,
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||||
|
|
||||||
|
@ -63,11 +63,11 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
port: int = 5432
|
port: int = 5432
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
user: str
|
user: str
|
||||||
password: str | None = None
|
password: SecretStr | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def engine_str(self) -> str:
|
def engine_str(self) -> str:
|
||||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
return f"postgresql+asyncpg://{self.user}:{self.password.get_secret_value() if self.password else ''}@{self.host}:{self.port}/{self.db}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def pip_packages(cls) -> list[str]:
|
def pip_packages(cls) -> list[str]:
|
||||||
|
|
|
@ -33,7 +33,7 @@ def test_groq_provider_openai_client_caching():
|
||||||
with request_provider_data_context(
|
with request_provider_data_context(
|
||||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||||
):
|
):
|
||||||
assert inference_adapter.client.api_key == api_key
|
assert inference_adapter.client.api_key.get_secret_value() == api_key
|
||||||
|
|
||||||
|
|
||||||
def test_openai_provider_openai_client_caching():
|
def test_openai_provider_openai_client_caching():
|
||||||
|
@ -52,7 +52,7 @@ def test_openai_provider_openai_client_caching():
|
||||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||||
):
|
):
|
||||||
openai_client = inference_adapter.client
|
openai_client = inference_adapter.client
|
||||||
assert openai_client.api_key == api_key
|
assert openai_client.api_key.get_secret_value() == api_key
|
||||||
|
|
||||||
|
|
||||||
def test_together_provider_openai_client_caching():
|
def test_together_provider_openai_client_caching():
|
||||||
|
@ -86,4 +86,4 @@ def test_llama_compat_provider_openai_client_caching():
|
||||||
|
|
||||||
for api_key in ["test1", "test2"]:
|
for api_key in ["test1", "test2"]:
|
||||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
|
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
|
||||||
assert inference_adapter.client.api_key == api_key
|
assert inference_adapter.client.api_key.get_secret_value() == api_key
|
||||||
|
|
|
@ -8,7 +8,7 @@ import json
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.core.request_headers import request_provider_data_context
|
from llama_stack.core.request_headers import request_provider_data_context
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
|
@ -16,11 +16,11 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
||||||
|
|
||||||
# Test fixtures and helper classes
|
# Test fixtures and helper classes
|
||||||
class TestConfig(BaseModel):
|
class TestConfig(BaseModel):
|
||||||
api_key: str | None = Field(default=None)
|
api_key: SecretStr | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class TestProviderDataValidator(BaseModel):
|
class TestProviderDataValidator(BaseModel):
|
||||||
test_api_key: str | None = Field(default=None)
|
test_api_key: SecretStr | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||||
|
@ -36,7 +36,7 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def adapter_with_config_key():
|
def adapter_with_config_key():
|
||||||
"""Fixture to create adapter with API key in config"""
|
"""Fixture to create adapter with API key in config"""
|
||||||
config = TestConfig(api_key="config-api-key")
|
config = TestConfig(api_key=SecretStr("config-api-key"))
|
||||||
adapter = TestLiteLLMAdapter(config)
|
adapter = TestLiteLLMAdapter(config)
|
||||||
adapter.__provider_spec__ = MagicMock()
|
adapter.__provider_spec__ = MagicMock()
|
||||||
adapter.__provider_spec__.provider_data_validator = (
|
adapter.__provider_spec__.provider_data_validator = (
|
||||||
|
@ -59,7 +59,7 @@ def adapter_without_config_key():
|
||||||
|
|
||||||
def test_api_key_from_config_when_no_provider_data(adapter_with_config_key):
|
def test_api_key_from_config_when_no_provider_data(adapter_with_config_key):
|
||||||
"""Test that adapter uses config API key when no provider data is available"""
|
"""Test that adapter uses config API key when no provider data is available"""
|
||||||
api_key = adapter_with_config_key.get_api_key()
|
api_key = adapter_with_config_key.get_api_key().get_secret_value()
|
||||||
assert api_key == "config-api-key"
|
assert api_key == "config-api-key"
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,28 +68,28 @@ def test_provider_data_takes_priority_over_config(adapter_with_config_key):
|
||||||
with request_provider_data_context(
|
with request_provider_data_context(
|
||||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
|
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
|
||||||
):
|
):
|
||||||
api_key = adapter_with_config_key.get_api_key()
|
api_key = adapter_with_config_key.get_api_key().get_secret_value()
|
||||||
assert api_key == "provider-data-key"
|
assert api_key == "provider-data-key"
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_to_config_when_provider_data_missing_key(adapter_with_config_key):
|
def test_fallback_to_config_when_provider_data_missing_key(adapter_with_config_key):
|
||||||
"""Test fallback to config when provider data doesn't have the required key"""
|
"""Test fallback to config when provider data doesn't have the required key"""
|
||||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||||
api_key = adapter_with_config_key.get_api_key()
|
api_key = adapter_with_config_key.get_api_key().get_secret_value()
|
||||||
assert api_key == "config-api-key"
|
assert api_key == "config-api-key"
|
||||||
|
|
||||||
|
|
||||||
def test_error_when_no_api_key_available(adapter_without_config_key):
|
def test_error_when_no_api_key_available(adapter_without_config_key):
|
||||||
"""Test that ValueError is raised when neither config nor provider data have API key"""
|
"""Test that ValueError is raised when neither config nor provider data have API key"""
|
||||||
with pytest.raises(ValueError, match="API key is not set"):
|
with pytest.raises(ValueError, match="API key is not set"):
|
||||||
adapter_without_config_key.get_api_key()
|
adapter_without_config_key.get_api_key().get_secret_value()
|
||||||
|
|
||||||
|
|
||||||
def test_error_when_provider_data_has_wrong_key(adapter_without_config_key):
|
def test_error_when_provider_data_has_wrong_key(adapter_without_config_key):
|
||||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||||
with pytest.raises(ValueError, match="API key is not set"):
|
with pytest.raises(ValueError, match="API key is not set"):
|
||||||
adapter_without_config_key.get_api_key()
|
adapter_without_config_key.get_api_key().get_secret_value()
|
||||||
|
|
||||||
|
|
||||||
def test_provider_data_works_when_config_is_none(adapter_without_config_key):
|
def test_provider_data_works_when_config_is_none(adapter_without_config_key):
|
||||||
|
@ -97,14 +97,14 @@ def test_provider_data_works_when_config_is_none(adapter_without_config_key):
|
||||||
with request_provider_data_context(
|
with request_provider_data_context(
|
||||||
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-only-key"})}
|
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-only-key"})}
|
||||||
):
|
):
|
||||||
api_key = adapter_without_config_key.get_api_key()
|
api_key = adapter_without_config_key.get_api_key().get_secret_value()
|
||||||
assert api_key == "provider-only-key"
|
assert api_key == "provider-only-key"
|
||||||
|
|
||||||
|
|
||||||
def test_error_message_includes_correct_field_names(adapter_without_config_key):
|
def test_error_message_includes_correct_field_names(adapter_without_config_key):
|
||||||
"""Test that error message includes correct field name and header information"""
|
"""Test that error message includes correct field name and header information"""
|
||||||
try:
|
try:
|
||||||
adapter_without_config_key.get_api_key()
|
adapter_without_config_key.get_api_key().get_secret_value()
|
||||||
raise AssertionError("Should have raised ValueError")
|
raise AssertionError("Should have raised ValueError")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
assert "test_api_key" in str(e) # Should mention the correct field name
|
assert "test_api_key" in str(e) # Should mention the correct field name
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from llama_stack.core.stack import replace_env_vars
|
from llama_stack.core.stack import replace_env_vars
|
||||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||||
|
@ -59,14 +61,14 @@ class TestOpenAIBaseURLConfig:
|
||||||
adapter = OpenAIInferenceAdapter(config)
|
adapter = OpenAIInferenceAdapter(config)
|
||||||
|
|
||||||
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
|
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
|
||||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))
|
||||||
|
|
||||||
# Access the client property to trigger AsyncOpenAI initialization
|
# Access the client property to trigger AsyncOpenAI initialization
|
||||||
_ = adapter.client
|
_ = adapter.client
|
||||||
|
|
||||||
# Verify AsyncOpenAI was called with the correct base_url
|
# Verify AsyncOpenAI was called with the correct base_url
|
||||||
mock_openai_class.assert_called_once_with(
|
mock_openai_class.assert_called_once_with(
|
||||||
api_key="test-key",
|
api_key=SecretStr("test-key"),
|
||||||
base_url=custom_url,
|
base_url=custom_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,7 +80,7 @@ class TestOpenAIBaseURLConfig:
|
||||||
adapter = OpenAIInferenceAdapter(config)
|
adapter = OpenAIInferenceAdapter(config)
|
||||||
|
|
||||||
# Mock the get_api_key method
|
# Mock the get_api_key method
|
||||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))
|
||||||
|
|
||||||
# Mock a model object that will be returned by models.list()
|
# Mock a model object that will be returned by models.list()
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
|
@ -101,7 +103,7 @@ class TestOpenAIBaseURLConfig:
|
||||||
|
|
||||||
# Verify the client was created with the custom URL
|
# Verify the client was created with the custom URL
|
||||||
mock_openai_class.assert_called_with(
|
mock_openai_class.assert_called_with(
|
||||||
api_key="test-key",
|
api_key=SecretStr("test-key"),
|
||||||
base_url=custom_url,
|
base_url=custom_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -119,7 +121,7 @@ class TestOpenAIBaseURLConfig:
|
||||||
adapter = OpenAIInferenceAdapter(config)
|
adapter = OpenAIInferenceAdapter(config)
|
||||||
|
|
||||||
# Mock the get_api_key method
|
# Mock the get_api_key method
|
||||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
adapter.get_api_key = MagicMock(return_value=SecretStr("test-key"))
|
||||||
|
|
||||||
# Mock a model object that will be returned by models.list()
|
# Mock a model object that will be returned by models.list()
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
|
@ -142,6 +144,6 @@ class TestOpenAIBaseURLConfig:
|
||||||
|
|
||||||
# Verify the client was created with the environment variable URL
|
# Verify the client was created with the environment variable URL
|
||||||
mock_openai_class.assert_called_with(
|
mock_openai_class.assert_called_with(
|
||||||
api_key="test-key",
|
api_key=SecretStr("test-key"),
|
||||||
base_url="https://proxy.openai.com/v1",
|
base_url="https://proxy.openai.com/v1",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue