Merge pull request #5489 from BerriAI/litellm_Add_secret_managers

[Feat] Add Google Secret Manager Support
This commit is contained in:
Ishaan Jaff 2024-09-03 14:51:32 -07:00 committed by GitHub
commit 19dbfff620
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 633 additions and 327 deletions

View file

@ -23,7 +23,7 @@ This covers:
- ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs) - ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs)
- ✅ [JWT-Auth](../docs/proxy/token_auth.md) - ✅ [JWT-Auth](../docs/proxy/token_auth.md)
- ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes) - ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes)
- ✅ [[BETA] AWS Key Manager v2 - Key Decryption](./proxy/enterprise#beta-aws-key-manager---key-decryption) - ✅ [**Secret Managers** AWS Key Manager, Google Secret Manager, Azure Key](./secret)
- ✅ IP addressbased access control lists - ✅ IP addressbased access control lists
- ✅ Track Request IP Address - ✅ Track Request IP Address
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints) - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)

View file

@ -17,7 +17,7 @@ Features:
- ✅ [Audit Logs with retention policy](#audit-logs) - ✅ [Audit Logs with retention policy](#audit-logs)
- ✅ [JWT-Auth](../docs/proxy/token_auth.md) - ✅ [JWT-Auth](../docs/proxy/token_auth.md)
- ✅ [Control available public, private routes](#control-available-public-private-routes) - ✅ [Control available public, private routes](#control-available-public-private-routes)
- ✅ [[BETA] AWS Key Manager v2 - Key Decryption](#beta-aws-key-manager---key-decryption) - ✅ [**Secret Managers** AWS Key Manager, Google Secret Manager, Azure Key](../secret)
- ✅ IP addressbased access control lists - ✅ IP addressbased access control lists
- ✅ Track Request IP Address - ✅ Track Request IP Address
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints) - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)

View file

@ -1,9 +1,22 @@
# Secret Manager # Secret Manager
LiteLLM supports reading secrets from Azure Key Vault and Infisical LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
- AWS Key Managemenet Service :::info
✨ **This is an Enterprise Feature**
[Enterprise Pricing](https://www.litellm.ai/#pricing)
[Contact us here to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
## Supported Secret Managers
- AWS Key Management Service
- AWS Secret Manager - AWS Secret Manager
- [Azure Key Vault](#azure-key-vault) - [Azure Key Vault](#azure-key-vault)
- [Google Secret Manager](#google-secret-manager)
- Google Key Management Service - Google Key Management Service
- [Infisical Secret Manager](#infisical-secret-manager) - [Infisical Secret Manager](#infisical-secret-manager)
- [.env Files](#env-files) - [.env Files](#env-files)
@ -125,6 +138,45 @@ litellm --config /path/to/config.yaml
[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js) [Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
## Google Secret Manager
Support for [Google Secret Manager](https://cloud.google.com/security/products/secret-manager)
1. Save Google Secret Manager details in your environment
```shell
GOOGLE_SECRET_MANAGER_PROJECT_ID="your-project-id-on-gcp" # example: adroit-crow-413218
```
Optional Params
```shell
export GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL = "" # (int) defaults to 86400
export GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER = "" # (str) set to "true" if you want to always read from google secret manager without using in memory caching. NOT RECOMMENDED in PROD
```
2. Add to proxy config.yaml
```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_base: https://exampleopenaiendpoint-production.up.railway.app/
api_key: os.environ/OPENAI_API_KEY # this will be read from Google Secret Manager
general_settings:
key_management_system: "google_secret_manager"
```
You can now test this by starting your proxy:
```bash
litellm --config /path/to/config.yaml
```
[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
## Google Key Management Service ## Google Key Management Service
Use encrypted keys from Google KMS on the proxy Use encrypted keys from Google KMS on the proxy

View file

@ -834,7 +834,6 @@ from .utils import (
decode, decode,
_calculate_retry_after, _calculate_retry_after,
_should_retry, _should_retry,
get_secret,
get_supported_openai_params, get_supported_openai_params,
get_api_base, get_api_base,
get_first_chars_messages, get_first_chars_messages,

View file

@ -22,6 +22,7 @@ import litellm
from litellm import client from litellm import client
from litellm.llms.azure import AzureBatchesAPI from litellm.llms.azure import AzureBatchesAPI
from litellm.llms.openai import OpenAIBatchesAPI from litellm.llms.openai import OpenAIBatchesAPI
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
Batch, Batch,
CancelBatchRequest, CancelBatchRequest,
@ -34,7 +35,7 @@ from litellm.types.llms.openai import (
RetrieveBatchRequest, RetrieveBatchRequest,
) )
from litellm.types.router import GenericLiteLLMParams from litellm.types.router import GenericLiteLLMParams
from litellm.utils import get_secret, supports_httpx_timeout from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI() openai_batches_instance = OpenAIBatchesAPI()

View file

@ -17,7 +17,6 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx import httpx
import litellm import litellm
from litellm import get_secret
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI
from litellm.llms.fine_tuning_apis.openai import ( from litellm.llms.fine_tuning_apis.openai import (
@ -26,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
OpenAIFineTuningAPI, OpenAIFineTuningAPI,
) )
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import Hyperparameters from litellm.types.llms.openai import Hyperparameters
from litellm.types.router import * from litellm.types.router import *
from litellm.utils import supports_httpx_timeout from litellm.utils import supports_httpx_timeout

View file

@ -52,6 +52,27 @@ class GCSBucketBase(CustomLogger):
return headers return headers
def sync_construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
auth_header, _ = vertex_chat_completion._get_token_and_url(
model="gcs-bucket",
vertex_credentials=self.path_service_account_json,
vertex_project=None,
vertex_location=None,
gemini_api_key=None,
stream=None,
custom_llm_provider="vertex_ai",
api_base=None,
)
verbose_logger.debug("constructed auth_header %s", auth_header)
headers = {
"Authorization": f"Bearer {auth_header}", # auth_header
"Content-Type": "application/json",
}
return headers
async def download_gcs_object(self, object_name): async def download_gcs_object(self, object_name):
""" """
Download an object from GCS. Download an object from GCS.

View file

@ -5,7 +5,7 @@ import httpx
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.caching import DualCache, InMemoryCache from litellm.caching import DualCache, InMemoryCache
from litellm.utils import get_secret from litellm.secret_managers.main import get_secret
from .base import BaseLLM from .base import BaseLLM

View file

@ -10,7 +10,7 @@ from typing import List, Optional, Union
import httpx import httpx
import litellm import litellm
from litellm import get_secret from litellm.secret_managers.main import get_secret
class BedrockError(Exception): class BedrockError(Exception):

View file

@ -11,7 +11,6 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx import httpx
import litellm import litellm
from litellm import get_secret
from litellm.llms.cohere.embed import embedding as cohere_embedding from litellm.llms.cohere.embed import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
@ -19,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client, _get_async_httpx_client,
_get_httpx_client, _get_httpx_client,
) )
from litellm.secret_managers.main import get_secret
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
from litellm.types.utils import Embedding, EmbeddingResponse, Usage from litellm.types.utils import Embedding, EmbeddingResponse, Usage

View file

@ -1063,6 +1063,7 @@ class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms" GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault" AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager" AWS_SECRET_MANAGER = "aws_secret_manager"
GOOGLE_SECRET_MANAGER = "google_secret_manager"
LOCAL = "local" LOCAL = "local"
AWS_KMS = "aws_kms" AWS_KMS = "aws_kms"

View file

@ -14,7 +14,7 @@ def init_rds_client(
aws_web_identity_token: Optional[str] = None, aws_web_identity_token: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None,
): ):
from litellm import get_secret from litellm.secret_managers.main import get_secret
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

View file

@ -1,16 +1,17 @@
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm.proxy.utils import hash_token
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_Config,
LiteLLM_UserTable,
LiteLLM_VerificationToken,
)
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy.utils import hash_token
from litellm.secret_managers.main import get_secret
class DynamoDBWrapper(CustomDB): class DynamoDBWrapper(CustomDB):
@ -21,19 +22,19 @@ class DynamoDBWrapper(CustomDB):
def __init__(self, database_arguments: DynamoDBArgs): def __init__(self, database_arguments: DynamoDBArgs):
from aiodynamo.client import Client from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import ( from aiodynamo.models import (
Throughput,
KeySchema, KeySchema,
KeySpec, KeySpec,
KeyType, KeyType,
PayPerRequest, PayPerRequest,
ReturnValues,
Throughput,
) )
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession from aiohttp import ClientSession
from yarl import URL
self.throughput_type = None self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST": if database_arguments.billing_mode == "PAY_PER_REQUEST":
@ -59,7 +60,9 @@ class DynamoDBWrapper(CustomDB):
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}" f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
) )
import boto3, os import os
import boto3
sts_client = boto3.client("sts") sts_client = boto3.client("sts")
@ -92,22 +95,22 @@ class DynamoDBWrapper(CustomDB):
""" """
Connect to DB, and creating / updating any tables Connect to DB, and creating / updating any tables
""" """
import aiohttp
from aiodynamo.client import Client from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import ( from aiodynamo.models import (
Throughput,
KeySchema, KeySchema,
KeySpec, KeySpec,
KeyType, KeyType,
PayPerRequest, PayPerRequest,
ReturnValues,
Throughput,
) )
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession from aiohttp import ClientSession
import aiohttp from yarl import URL
verbose_proxy_logger.debug("DynamoDB Wrapper - Attempting to connect") verbose_proxy_logger.debug("DynamoDB Wrapper - Attempting to connect")
self.set_env_vars_based_on_arn() self.set_env_vars_based_on_arn()
@ -192,22 +195,22 @@ class DynamoDBWrapper(CustomDB):
async def insert_data( async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config", "spend"] self, value: Any, table_name: Literal["user", "key", "config", "spend"]
): ):
import aiohttp
from aiodynamo.client import Client from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import ( from aiodynamo.models import (
Throughput,
KeySchema, KeySchema,
KeySpec, KeySpec,
KeyType, KeyType,
PayPerRequest, PayPerRequest,
ReturnValues,
Throughput,
) )
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession from aiohttp import ClientSession
import aiohttp from yarl import URL
self.set_env_vars_based_on_arn() self.set_env_vars_based_on_arn()
@ -237,22 +240,22 @@ class DynamoDBWrapper(CustomDB):
return await table.put_item(item=value, return_values=ReturnValues.all_old) return await table.put_item(item=value, return_values=ReturnValues.all_old)
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
import aiohttp
from aiodynamo.client import Client from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import ( from aiodynamo.models import (
Throughput,
KeySchema, KeySchema,
KeySpec, KeySpec,
KeyType, KeyType,
PayPerRequest, PayPerRequest,
ReturnValues,
Throughput,
) )
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession from aiohttp import ClientSession
import aiohttp from yarl import URL
self.set_env_vars_based_on_arn() self.set_env_vars_based_on_arn()
@ -311,22 +314,22 @@ class DynamoDBWrapper(CustomDB):
self, key: str, value: dict, table_name: Literal["user", "key", "config"] self, key: str, value: dict, table_name: Literal["user", "key", "config"]
): ):
self.set_env_vars_based_on_arn() self.set_env_vars_based_on_arn()
import aiohttp
from aiodynamo.client import Client from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import ( from aiodynamo.models import (
Throughput,
KeySchema, KeySchema,
KeySpec, KeySpec,
KeyType, KeyType,
PayPerRequest, PayPerRequest,
ReturnValues,
Throughput,
) )
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession from aiohttp import ClientSession
import aiohttp from yarl import URL
if self.database_arguments.ssl_verify == False: if self.database_arguments.ssl_verify == False:
client_session = ClientSession(connector=aiohttp.TCPConnector(ssl=False)) client_session = ClientSession(connector=aiohttp.TCPConnector(ssl=False))

View file

@ -24,7 +24,6 @@ import httpx
from fastapi import HTTPException from fastapi import HTTPException
import litellm import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import CustomGuardrail
@ -38,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import (
) )
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import ( from litellm.types.guardrails import (
BedrockContentItem, BedrockContentItem,
BedrockRequest, BedrockRequest,

View file

@ -19,12 +19,12 @@ import httpx
from fastapi import HTTPException from fastapi import HTTPException
import litellm import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import ( from litellm.types.guardrails import (
GuardrailItem, GuardrailItem,
LakeraCategoryThresholds, LakeraCategoryThresholds,

View file

@ -9,7 +9,7 @@ import time
sys.path.insert( sys.path.insert(
0, os.path.abspath("./") 0, os.path.abspath("./")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var from litellm.secret_managers.aws_secret_manager import decrypt_env_var
if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True":
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV

View file

@ -475,7 +475,7 @@ def run_server(
### DECRYPT ENV VAR ### ### DECRYPT ENV VAR ###
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var from litellm.secret_managers.aws_secret_manager import decrypt_env_var
if ( if (
os.getenv("USE_AWS_KMS", None) is not None os.getenv("USE_AWS_KMS", None) is not None
@ -544,6 +544,15 @@ def run_server(
load_aws_secret_manager(use_aws_secret_manager=True) load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value: elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True) load_aws_kms(use_aws_kms=True)
elif (
key_management_system
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
):
from litellm.secret_managers.google_secret_manager import (
GoogleSecretManager,
)
GoogleSecretManager()
else: else:
raise ValueError("Invalid Key Management System selected") raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get( key_management_settings = general_settings.get(
@ -598,7 +607,7 @@ def run_server(
or os.getenv("DIRECT_URL", None) is not None or os.getenv("DIRECT_URL", None) is not None
): ):
try: try:
from litellm import get_secret from litellm.secret_managers.main import get_secret
if os.getenv("DATABASE_URL", None) is not None: if os.getenv("DATABASE_URL", None) is not None:
### add connection pool + pool timeout args ### add connection pool + pool timeout args

View file

@ -1,15 +1,20 @@
model_list: model_list:
- model_name: fake-openai-endpoint - model_name: fake-openai-endpoint
litellm_params: litellm_params:
model: openai/fake model: openai/fake
api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/
model_info: api_key: os.environ/OPENAI_API_KEY
id: "team-a-model" # used for identifying model in response headers - model_name: gpt-3.5-turbo-end-user-test
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY
litellm_settings: litellm_settings:
success_callback: ["prometheus"] success_callback: ["prometheus"]
failure_callback: ["prometheus"] failure_callback: ["prometheus"]
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth key_management_system: "google_secret_manager"

View file

@ -212,11 +212,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
) )
from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
from litellm.proxy.route_llm_request import route_request from litellm.proxy.route_llm_request import route_request
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.spend_tracking.spend_management_endpoints import ( from litellm.proxy.spend_tracking.spend_management_endpoints import (
router as spend_management_router, router as spend_management_router,
) )
@ -257,6 +252,11 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.types.llms.anthropic import ( from litellm.types.llms.anthropic import (
AnthropicMessagesRequest, AnthropicMessagesRequest,
AnthropicResponse, AnthropicResponse,
@ -1765,6 +1765,15 @@ class ProxyConfig:
load_aws_secret_manager(use_aws_secret_manager=True) load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value: elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True) load_aws_kms(use_aws_kms=True)
elif (
key_management_system
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
):
from litellm.secret_managers.google_secret_manager import (
GoogleSecretManager,
)
GoogleSecretManager()
else: else:
raise ValueError("Invalid Key Management System selected") raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get( key_management_settings = general_settings.get(

View file

@ -4,10 +4,10 @@ from functools import partial
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
import litellm import litellm
from litellm import get_secret
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.togetherai.rerank import TogetherAIRerank from litellm.llms.togetherai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.router import * from litellm.types.router import *
from litellm.utils import supports_httpx_timeout from litellm.utils import supports_httpx_timeout

View file

@ -9,7 +9,7 @@ import openai
import litellm import litellm
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.proxy.secret_managers.get_azure_ad_token_provider import ( from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider, get_azure_ad_token_provider,
) )
from litellm.utils import calculate_max_parallel_requests from litellm.utils import calculate_max_parallel_requests

View file

@ -0,0 +1,3 @@
## Supported Secret Managers to read credentials from
Example read OPENAI_API_KEY, AZURE_API_KEY from a secret manager

View file

@ -14,8 +14,7 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
Returns: Returns:
Callable that returns a temporary authentication token. Callable that returns a temporary authentication token.
""" """
from azure.identity import ClientSecretCredential from azure.identity import ClientSecretCredential, get_bearer_token_provider
from azure.identity import get_bearer_token_provider
try: try:
credential = ClientSecretCredential( credential = ClientSecretCredential(
@ -24,7 +23,9 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
tenant_id=os.environ["AZURE_TENANT_ID"], tenant_id=os.environ["AZURE_TENANT_ID"],
) )
except KeyError as e: except KeyError as e:
raise ValueError("Missing environment variable required by Azure AD workflow.") from e raise ValueError(
"Missing environment variable required by Azure AD workflow."
) from e
return get_bearer_token_provider( return get_bearer_token_provider(
credential, credential,

View file

@ -7,8 +7,11 @@ Requires:
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]` * `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
* `pip install google-cloud-kms` * `pip install google-cloud-kms`
""" """
import litellm, os
import os
from typing import Optional from typing import Optional
import litellm
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem

View file

@ -0,0 +1,116 @@
import base64
import os
from typing import Optional
import litellm
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache
from litellm.integrations.gcs_bucket_base import GCSBucketBase
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.proxy._types import CommonProxyErrors, KeyManagementSystem
class GoogleSecretManager(GCSBucketBase):
def __init__(
self,
refresh_interval: Optional[int] = 86400,
always_read_secret_manager: Optional[bool] = False,
) -> None:
"""
Args:
refresh_interval (int, optional): The refresh interval in seconds. Defaults to 86400. (24 hours)
always_read_secret_manager (bool, optional): Whether to always read from the secret manager. Defaults to False. Since we do want to cache values
"""
from litellm.proxy.proxy_server import premium_user
if premium_user is not True:
raise ValueError(
f"Google Secret Manager requires an Enterprise License {CommonProxyErrors.not_premium_user.value}"
)
super().__init__()
self.PROJECT_ID = os.environ.get("GOOGLE_SECRET_MANAGER_PROJECT_ID", None)
if self.PROJECT_ID is None:
raise ValueError(
"Google Secret Manager requires a project ID, please set 'GOOGLE_SECRET_MANAGER_PROJECT_ID' in your .env"
)
self.sync_httpx_client = _get_httpx_client()
litellm.secret_manager_client = self
litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER
_refresh_interval = os.environ.get(
"GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL", refresh_interval
)
_refresh_interval = (
int(_refresh_interval) if _refresh_interval else refresh_interval
)
self.cache = InMemoryCache(
default_ttl=_refresh_interval
) # store in memory for 1 day
_always_read_secret_manager = os.environ.get(
"GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER",
)
if (
_always_read_secret_manager
and _always_read_secret_manager.lower() == "true"
):
self.always_read_secret_manager = True
else:
# by default this should be False, we want to use in memory caching for this. It's a bad idea to fetch from secret manager for all requests
self.always_read_secret_manager = always_read_secret_manager or False
def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
"""
Retrieve a secret from Google Secret Manager or cache.
Args:
secret_name (str): The name of the secret.
Returns:
str: The secret value if successful, None otherwise.
"""
if self.always_read_secret_manager is not True:
cached_secret = self.cache.get_cache(secret_name)
if cached_secret is not None:
return cached_secret
if secret_name in self.cache.cache_dict:
return cached_secret
_secret_name = (
f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
)
headers = self.sync_construct_request_headers()
url = f"https://secretmanager.googleapis.com/v1/{_secret_name}:access"
# Send the GET request to retrieve the secret
response = self.sync_httpx_client.get(url=url, headers=headers)
if response.status_code != 200:
verbose_logger.error(
"Google Secret Manager retrieval error: %s", str(response.text)
)
self.cache.set_cache(
secret_name, None
) # Cache that the secret was not found
raise ValueError(
f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
)
verbose_logger.debug(
"Google Secret Manager retrieval response status code: %s",
response.status_code,
)
# Parse the JSON response and return the secret value
secret_data = response.json()
_base64_encoded_value = secret_data.get("payload", {}).get("data")
# decode the base64 encoded value
if _base64_encoded_value is not None:
_decoded_value = base64.b64decode(_base64_encoded_value).decode("utf-8")
self.cache.set_cache(
secret_name, _decoded_value
) # Cache the retrieved secret
return _decoded_value
self.cache.set_cache(secret_name, None) # Cache that the secret was not found
raise ValueError(f"secret {secret_name} not found in Google Secret Manager")

View file

@ -0,0 +1,276 @@
import ast
import base64
import binascii
import json
import os
import sys
import traceback
from typing import Any, Optional, Union
import httpx
from dotenv import load_dotenv
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching import DualCache
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.proxy._types import KeyManagementSystem
oidc_cache = DualCache()
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
elif oidc_provider == "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
elif oidc_provider == "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
elif oidc_provider == "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if (
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = oidc_client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.text["value"]
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "file":
# Load token from a file
with open(oidc_aud, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "env":
# Load token directly from an environment variable
oidc_token = os.getenv(oidc_aud)
if oidc_token is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
return oidc_token
elif oidc_provider == "env_path":
# Load token from a file path specified in an environment variable
token_file_path = os.getenv(oidc_aud)
if token_file_path is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
with open(token_file_path, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
ciphertext = encrypted_secret
else:
raise ValueError(
f"Google KMS requires the encrypted secret to be encoded in base64"
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
if key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(
secret_name
)
print_verbose(f"secret from google secret manager: {secret}")
if secret is None:
raise ValueError(
f"No secret found in Google Secret Manager for {secret_name}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
raise e
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
verbose_logger.error(
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
else:
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = (
ast.literal_eval(secret) if secret is not None else None
)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except Exception:
if default_value is not None:
return default_value
return secret
except Exception as e:
if default_value is not None:
return default_value
else:
raise e

View file

@ -83,7 +83,7 @@ async def test_router_init():
) )
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os") @patch("litellm.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret( def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
mocked_os_lib: MagicMock, mocked_os_lib: MagicMock,
) -> None: ) -> None:
@ -128,7 +128,7 @@ def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secre
@patch("azure.identity.get_bearer_token_provider") @patch("azure.identity.get_bearer_token_provider")
@patch("azure.identity.ClientSecretCredential") @patch("azure.identity.ClientSecretCredential")
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os") @patch("litellm.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_azure_service_principal_with_secret_with_environment_variables( def test_router_init_azure_service_principal_with_secret_with_environment_variables(
mocked_os_lib: MagicMock, mocked_os_lib: MagicMock,
mocked_credential: MagicMock, mocked_credential: MagicMock,

View file

@ -16,10 +16,10 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
from litellm import get_secret
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.secret_managers.main import get_secret
@pytest.mark.skip(reason="AWS Suspended Account") @pytest.mark.skip(reason="AWS Suspended Account")
@ -189,3 +189,54 @@ def test_oidc_env_path():
assert secret_val == secret_value assert secret_val == secret_value
del os.environ[env_var_name] del os.environ[env_var_name]
def test_google_secret_manager():
"""
Test that we can get a secret from Google Secret Manager
"""
os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "adroit-crow-413218"
from test_amazing_vertex_completion import load_vertex_ai_credentials
from litellm.secret_managers.google_secret_manager import GoogleSecretManager
# load_vertex_ai_credentials()
secret_manager = GoogleSecretManager()
secret_val = secret_manager.get_secret_from_google_secret_manager(
secret_name="OPENAI_API_KEY"
)
print("secret_val: {}".format(secret_val))
assert (
secret_val == "anything"
), "did not get expected secret value. expect 'anything', got '{}'".format(
secret_val
)
def test_google_secret_manager_read_in_memory():
"""
Test that Google Secret manager returs in memory value when it exists
"""
from test_amazing_vertex_completion import load_vertex_ai_credentials
from litellm.secret_managers.google_secret_manager import GoogleSecretManager
# load_vertex_ai_credentials()
os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "adroit-crow-413218"
secret_manager = GoogleSecretManager()
secret_manager.cache.cache_dict["UNIQUE_KEY"] = None
secret_manager.cache.cache_dict["UNIQUE_KEY_2"] = "lite-llm"
secret_val = secret_manager.get_secret_from_google_secret_manager(
secret_name="UNIQUE_KEY"
)
print("secret_val: {}".format(secret_val))
assert secret_val == None
secret_val = secret_manager.get_secret_from_google_secret_manager(
secret_name="UNIQUE_KEY_2"
)
print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm"

View file

@ -3712,6 +3712,7 @@ def test_unit_test_custom_stream_wrapper_function_call():
"vertex_ai/claude-3-5-sonnet@20240620", "vertex_ai/claude-3-5-sonnet@20240620",
], ],
) )
@pytest.mark.flaky(retries=3, delay=1)
def test_streaming_tool_calls_valid_json_str(model): def test_streaming_tool_calls_valid_json_str(model):
if "vertex_ai" in model: if "vertex_ai" in model:
from litellm.tests.test_amazing_vertex_completion import ( from litellm.tests.test_amazing_vertex_completion import (

View file

@ -68,6 +68,7 @@ from litellm.litellm_core_utils.redact_messages import (
) )
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
AllMessageValues, AllMessageValues,
ChatCompletionNamedToolChoiceParam, ChatCompletionNamedToolChoiceParam,
@ -93,8 +94,6 @@ from litellm.types.utils import (
Usage, Usage,
) )
oidc_cache = DualCache()
try: try:
# New and recommended way to access resources # New and recommended way to access resources
from importlib import resources from importlib import resources
@ -8662,250 +8661,6 @@ def exception_type(
raise raised_exc raise raised_exc
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
elif oidc_provider == "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
elif oidc_provider == "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
elif oidc_provider == "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if (
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = oidc_client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.text["value"]
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "file":
# Load token from a file
with open(oidc_aud, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "env":
# Load token directly from an environment variable
oidc_token = os.getenv(oidc_aud)
if oidc_token is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
return oidc_token
elif oidc_provider == "env_path":
# Load token from a file path specified in an environment variable
token_file_path = os.getenv(oidc_aud)
if token_file_path is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
with open(token_file_path, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
ciphertext = encrypted_secret
else:
raise ValueError(
f"Google KMS requires the encrypted secret to be encoded in base64"
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
verbose_logger.error(
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
else:
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = (
ast.literal_eval(secret) if secret is not None else None
)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except Exception:
if default_value is not None:
return default_value
return secret
except Exception as e:
if default_value is not None:
return default_value
else:
raise e
######## Streaming Class ############################ ######## Streaming Class ############################
# wraps the completion stream to return the correct format for the model # wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere # replicate/anthropic/cohere

View file

@ -17,7 +17,7 @@ def test_decrypt_and_reset_env():
os.environ["DATABASE_URL"] = ( os.environ["DATABASE_URL"] = (
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
) )
from litellm.proxy.secret_managers.aws_secret_manager import ( from litellm.secret_managers.aws_secret_manager import (
decrypt_and_reset_env_var, decrypt_and_reset_env_var,
) )