mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
* add SecretManager to httpxSpecialProvider * fix importing AWSSecretsManagerV2 * add unit testing for writing keys to AWS secret manager * use KeyManagementEventHooks for key/generated events * us event hooks for key management endpoints * working AWSSecretsManagerV2 * fix write secret to AWS secret manager on /key/generate * fix KeyManagementSettings * use tasks for key management hooks * add async_delete_secret * add test for async_delete_secret * use _delete_virtual_keys_from_secret_manager * fix test secret manager * test_key_generate_with_secret_manager_call * fix check for key_management_settings * sync_read_secret * test_aws_secret_manager * fix sync_read_secret * use helper to check when _should_read_secret_from_secret_manager * test_get_secret_with_access_mode * test - handle eol model claude-2, use claude-2.1 instead * docs AWS secret manager * fix test_read_nonexistent_secret * fix test_supports_response_schema * ci/cd run again
143 lines
4.5 KiB
Python
143 lines
4.5 KiB
Python
"""
|
|
This file implements the AWS Key Management Service (KMS) integration
|
|
|
|
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
|
|
|
Requires:
|
|
* `os.environ["AWS_REGION_NAME"]`
|
|
* `pip install boto3>=1.28.57`
|
|
"""
|
|
|
|
import ast
|
|
import base64
|
|
import os
|
|
import re
|
|
from typing import Any, Dict, Optional
|
|
|
|
import litellm
|
|
from litellm.proxy._types import KeyManagementSystem
|
|
|
|
|
|
def validate_environment():
    """Ensure the AWS region needed to build a KMS client is configured.

    Raises:
        ValueError: if ``AWS_REGION_NAME`` is not present in ``os.environ``.
    """
    region_configured = "AWS_REGION_NAME" in os.environ
    if not region_configured:
        raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
|
|
|
|
|
def load_aws_kms(use_aws_kms: Optional[bool]):
    """Create a boto3 KMS client and register it as litellm's secret manager.

    Does nothing when ``use_aws_kms`` is ``None`` or ``False``.

    Args:
        use_aws_kms: whether AWS KMS should be enabled.

    Raises:
        ValueError: if ``AWS_REGION_NAME`` is not set.
        ImportError: if ``boto3`` is not installed.
    """
    if use_aws_kms is None or use_aws_kms is False:
        return

    # Imported here, after the early return, so boto3 is only needed when
    # KMS is actually enabled.
    import boto3

    validate_environment()

    # Create a KMS client in the configured region.
    kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    litellm.secret_manager_client = kms_client
    litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
|
|
|
|
|
class AWSKeyManagementService_V2:
    """
    V2 Clean Class for decrypting keys from AWS KeyManagementService.

    Construction validates the environment (AWS region + enterprise license)
    and stores a boto3 KMS client on ``self.kms_client``.
    """

    # Marker that may prefix an env var value holding base64 KMS ciphertext.
    _AWS_KMS_VALUE_PREFIX = "aws_kms/"

    def __init__(self) -> None:
        self.validate_environment()
        self.kms_client = self.load_aws_kms(use_aws_kms=True)

    def validate_environment(
        self,
    ):
        """Check that the AWS region and an enterprise license are configured.

        Raises:
            ValueError: if ``AWS_REGION_NAME`` is unset, or if neither
                ``LITELLM_LICENSE`` nor ``LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE``
                is set.
        """
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

        ## CHECK IF LICENSE IN ENV ## - premium feature
        # NOTE(review): the LITELLM_SECRET_AWS_KMS_ variant is presumably a
        # KMS-encrypted copy of the license; only its presence is checked here.
        is_litellm_license_in_env = (
            os.getenv("LITELLM_LICENSE", None) is not None
            or os.getenv("LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE", None) is not None
        )
        if not is_litellm_license_in_env:
            raise ValueError(
                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your environment."
            )

    def load_aws_kms(self, use_aws_kms: Optional[bool]):
        """Return a boto3 KMS client for ``AWS_REGION_NAME``.

        Returns ``None`` when ``use_aws_kms`` is ``None`` or ``False``.

        Raises:
            ValueError: if the environment fails ``validate_environment``.
            ImportError: if ``boto3`` is not installed.
        """
        if use_aws_kms is None or use_aws_kms is False:
            return None

        # Imported lazily so boto3 is only needed when a client is requested.
        import boto3

        # Use this class's own validation, consistent with __init__.
        self.validate_environment()

        return boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    def decrypt_value(self, secret_name: str) -> Any:
        """Decrypt the KMS ciphertext stored in the env var ``secret_name``.

        The env value may carry a leading ``aws_kms/`` marker; the remainder
        must be base64-encoded ciphertext. The decrypted UTF-8 plaintext is
        whitespace-stripped; if it parses as a Python bool literal (e.g.
        ``"True"``), that bool is returned instead of the string.

        Raises:
            ValueError: if no KMS client is configured.
            Exception: if the env var ``secret_name`` is not set.
        """
        if self.kms_client is None:
            raise ValueError("kms_client is None")
        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )
        # Strip only the leading marker, never occurrences elsewhere.
        if encrypted_value.startswith(self._AWS_KMS_VALUE_PREFIX):
            encrypted_value = encrypted_value[len(self._AWS_KMS_VALUE_PREFIX) :]

        # Decode the base64 encoded ciphertext and ask KMS to decrypt it.
        ciphertext_blob = base64.b64decode(encrypted_value)
        response = self.kms_client.decrypt(CiphertextBlob=ciphertext_blob)

        # The plaintext comes back as bytes; decode and trim it.
        secret = response["Plaintext"].decode("utf-8").strip()

        # Best-effort bool conversion: anything literal_eval rejects, or that
        # is not a bool, is returned as the plain string.
        try:
            secret_value_as_bool = ast.literal_eval(secret)
        except Exception:
            return secret
        if isinstance(secret_value_as_bool, bool):
            return secret_value_as_bool
        return secret
|
|
|
|
|
|
"""
|
|
- look for all values in the env with `aws_kms/<hashed_key>`
|
|
- decrypt keys
|
|
- return the decrypted values keyed by env var name (with any leading `LITELLM_SECRET_AWS_KMS_` prefix removed) so the caller can rewrite the env vars. Note: such a rewritten environment variable is only available to the current process and any child processes spawned from it; once the Python script ends, it does not persist.
|
|
"""
|
|
|
|
|
|
def decrypt_env_var() -> Dict[str, Any]:
    """Decrypt every AWS-KMS-encrypted value found in ``os.environ``.

    An env var is decrypted when its name starts with ``litellm_secret_aws_kms``
    (case-insensitive) or its value starts with ``aws_kms/``. The result maps
    each env var name, minus a leading ``LITELLM_SECRET_AWS_KMS_`` prefix
    (case-insensitive), to its decrypted value. ``os.environ`` itself is not
    modified here.

    Raises:
        Exceptions from ``AWSKeyManagementService_V2`` propagate, e.g.
        ``ValueError`` when the region or license is missing, as well as any
        error raised while decoding/decrypting a value.
    """
    # setup client class
    aws_kms = AWSKeyManagementService_V2()

    new_values: Dict[str, Any] = {}
    for k, v in os.environ.items():
        is_prefixed_key = k.lower().startswith("litellm_secret_aws_kms")
        is_encrypted_value = v.startswith("aws_kms/")
        if not (is_prefixed_key or is_encrypted_value):
            continue
        decrypted_value = aws_kms.decrypt_value(secret_name=k)
        # Remove only a leading prefix, so e.g. "MY_LITELLM_SECRET_AWS_KMS_X"
        # keeps its name instead of collapsing to "MY_X".
        new_key = re.sub(r"^litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE)
        new_values[new_key] = decrypted_value

    return new_values
|