feat(aws_secret_manager.py): allows user to keep a hash of the proxy master key in their env

This commit is contained in:
Krrish Dholakia 2024-06-06 15:32:51 -07:00
parent 707a41fd6c
commit a2da2a8f16
5 changed files with 62 additions and 8 deletions

View file

@ -56,8 +56,10 @@ router_settings:
litellm_settings: litellm_settings:
success_callback: ["langfuse"] success_callback: ["langfuse"]
json_logs: true
general_settings: general_settings:
alerting: ["email"] alerting: ["email"]
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"]

View file

@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum):
AZURE_KEY_VAULT = "azure_key_vault" AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager" AWS_SECRET_MANAGER = "aws_secret_manager"
LOCAL = "local" LOCAL = "local"
AWS_KMS = "aws_kms"
class KeyManagementSettings(LiteLLMBase): class KeyManagementSettings(LiteLLMBase):

View file

@ -112,7 +112,10 @@ from litellm import (
CreateFileRequest, CreateFileRequest,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
load_aws_kms,
)
import pydantic import pydantic
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache from litellm.caching import DualCache, RedisCache
@ -2736,10 +2739,12 @@ class ProxyConfig:
load_google_kms(use_google_kms=True) load_google_kms(use_google_kms=True)
elif ( elif (
key_management_system key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
): ):
### LOAD FROM AWS SECRET MANAGER ### ### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True) load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
else: else:
raise ValueError("Invalid Key Management System selected") raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get( key_management_settings = general_settings.get(
@ -2773,6 +2778,7 @@ class ProxyConfig:
master_key = general_settings.get( master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
) )
if master_key and master_key.startswith("os.environ/"): if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key) master_key = litellm.get_secret(master_key)
if not isinstance(master_key, str): if not isinstance(master_key, str):
@ -4098,6 +4104,7 @@ async def chat_completion(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
global general_settings, user_debug, proxy_logging_obj, llm_model_list global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {} data = {}
try: try:
body = await request.body() body = await request.body()

View file

@ -8,6 +8,7 @@ Requires:
* `pip install boto3>=1.28.57` * `pip install boto3>=1.28.57`
""" """
import boto3.session
import litellm, os import litellm, os
from typing import Optional from typing import Optional
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
except Exception as e: except Exception as e:
raise e raise e
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
try:
import boto3
validate_environment()
# Create a Secrets Manager client
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
litellm.secret_manager_client = kms_client
litellm._key_management_system = KeyManagementSystem.AWS_KMS
except Exception as e:
raise e

View file

@ -10052,6 +10052,8 @@ def get_secret(
): ):
key_management_system = litellm._key_management_system key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings key_management_settings = litellm._key_management_settings
args = locals()
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "") secret_name = secret_name.replace("os.environ/", "")
@ -10139,13 +10141,13 @@ def get_secret(
key_manager = "local" key_manager = "local"
if ( if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__ or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient" == "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value secret = client.get_secret(secret_name).value
elif ( elif (
key_manager == KeyManagementSystem.GOOGLE_KMS key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient" or client.__class__.__name__ == "KeyManagementServiceClient"
): ):
encrypted_secret: Any = os.getenv(secret_name) encrypted_secret: Any = os.getenv(secret_name)
@ -10173,6 +10175,25 @@ def get_secret(
secret = response.plaintext.decode( secret = response.plaintext.decode(
"utf-8" "utf-8"
) # assumes the original value was encoded with utf-8 ) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception("encrypted value for AWS KMS cannot be None.")
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try: try:
get_secret_value_response = client.get_secret_value( get_secret_value_response = client.get_secret_value(
@ -10193,10 +10214,14 @@ def get_secret(
for k, v in secret_dict.items(): for k, v in secret_dict.items():
secret = v secret = v
print_verbose(f"secret: {secret}") print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}") verbose_logger.error(
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name) secret = os.getenv(secret_name)
try: try:
secret_value_as_bool = ast.literal_eval(secret) secret_value_as_bool = ast.literal_eval(secret)