Merge pull request #4437 from BerriAI/litellm_aws_kms_decryption

fix(initial-commit): decrypts aws keys in entrypoint.sh
This commit is contained in:
Krish Dholakia 2024-06-28 21:10:52 -07:00 committed by GitHub
commit d0c89ddbe3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 247 additions and 46 deletions

View file

@ -1,48 +1,13 @@
#!/bin/sh
#!/bin/bash
echo $(pwd)
# Check if DATABASE_URL is not set
if [ -z "$DATABASE_URL" ]; then
# Check if all required variables are provided
if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then
# Construct DATABASE_URL from the provided variables
DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}"
export DATABASE_URL
# Run the Python migration script
python3 litellm/proxy/prisma_migration.py
# Check if the Python script executed successfully
if [ $? -eq 0 ]; then
echo "Migration script ran successfully!"
else
echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL."
echo "Migration script failed!"
exit 1
fi
fi
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
if [ -z "$DIRECT_URL" ]; then
export DIRECT_URL=$DATABASE_URL
fi
# Apply migrations
retry_count=0
max_retries=3
exit_code=1
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
do
retry_count=$((retry_count+1))
echo "Attempt $retry_count..."
# Run the Prisma db push command
prisma db push --accept-data-loss
exit_code=$?
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
echo "Retrying in 10 seconds..."
sleep 10
fi
done
if [ $exit_code -ne 0 ]; then
echo "Unable to push database changes after $max_retries retries."
exit 1
fi
echo "Database push successful!"

View file

@ -0,0 +1,68 @@
# What is this?
## Script to apply initial prisma migration on Docker setup
#
# Steps, in order:
#   1. If USE_AWS_KMS == "True", decrypt every `aws_kms/`-prefixed env var and
#      write the plaintext back into this process's environment.
#   2. Ensure DATABASE_URL is set, building it from DATABASE_HOST /
#      DATABASE_USERNAME / DATABASE_PASSWORD / DATABASE_NAME when it is not.
#   3. Default DIRECT_URL to DATABASE_URL.
#   4. Run `prisma db push --accept-data-loss`, up to 3 attempts, 10s apart.
# Exits with status 1 when the database settings are missing or every push
# attempt fails.

import os
import subprocess
import sys
import time

sys.path.insert(
    0, os.path.abspath("./")
)  # Adds the current working directory to the system path
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var

if os.getenv("USE_AWS_KMS") == "True":
    ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
    new_env_var = decrypt_env_var()
    for k, v in new_env_var.items():
        # decrypt_env_var() is typed `dict[str, Any]`; os.environ only accepts
        # strings, so coerce anything else instead of crashing with TypeError.
        os.environ[k] = v if isinstance(v, str) else str(v)

# Check if DATABASE_URL is not set
database_url = os.getenv("DATABASE_URL")

if not database_url:
    # Check if all required variables are provided
    database_host = os.getenv("DATABASE_HOST")
    database_username = os.getenv("DATABASE_USERNAME")
    database_password = os.getenv("DATABASE_PASSWORD")
    database_name = os.getenv("DATABASE_NAME")

    if database_host and database_username and database_password and database_name:
        # Construct DATABASE_URL from the provided variables
        database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}"
        os.environ["DATABASE_URL"] = database_url
    else:
        print(  # noqa
            "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL."  # noqa
        )
        sys.exit(1)

# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
direct_url = os.getenv("DIRECT_URL")
if not direct_url:
    os.environ["DIRECT_URL"] = database_url

# Apply migrations
retry_count = 0
max_retries = 3
exit_code = 1

while retry_count < max_retries and exit_code != 0:
    retry_count += 1
    print(f"Attempt {retry_count}...")  # noqa

    # Run the Prisma db push command
    result = subprocess.run(
        ["prisma", "db", "push", "--accept-data-loss"],
        capture_output=True,
        text=True,
        errors="replace",  # never let odd bytes in prisma's output abort the retry loop
    )
    exit_code = result.returncode

    if exit_code != 0:
        # The output is captured, so surface it on failure; otherwise the
        # reason for a failed push is lost.
        if result.stdout:
            print(result.stdout)  # noqa
        if result.stderr:
            print(result.stderr, file=sys.stderr)  # noqa
        if retry_count < max_retries:
            print("Retrying in 10 seconds...")  # noqa
            time.sleep(10)

if exit_code != 0:
    print(f"Unable to push database changes after {max_retries} retries.")  # noqa
    sys.exit(1)

print("Database push successful!")  # noqa

View file

@ -442,6 +442,20 @@ def run_server(
db_connection_pool_limit = 100
db_connection_timeout = 60
### DECRYPT ENV VAR ###
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
if (
os.getenv("USE_AWS_KMS", None) is not None
and os.getenv("USE_AWS_KMS") == "True"
):
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
new_env_var = decrypt_env_var()
for k, v in new_env_var.items():
os.environ[k] = v
if config is not None:
"""
Allow user to pass in db url via config
@ -459,6 +473,7 @@ def run_server(
proxy_config = ProxyConfig()
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
### LITELLM SETTINGS ###
litellm_settings = _config.get("litellm_settings", None)
if (

View file

@ -8,9 +8,12 @@ Requires:
* `pip install boto3>=1.28.57`
"""
import litellm
import ast
import base64
import os
from typing import Optional
from typing import Any, Optional
import litellm
from litellm.proxy._types import KeyManagementSystem
@ -57,3 +60,94 @@ def load_aws_kms(use_aws_kms: Optional[bool]):
except Exception as e:
raise e
class AWSKeyManagementService_V2:
    """
    V2 Clean Class for decrypting keys from AWS KeyManagementService

    Construction validates the environment and creates a boto3 KMS client;
    `decrypt_value` then decrypts the `aws_kms/`-prefixed, base64-encoded
    ciphertext stored in a named environment variable.
    """

    def __init__(self) -> None:
        self.validate_environment()
        self.kms_client = self.load_aws_kms(use_aws_kms=True)

    def validate_environment(
        self,
    ):
        """
        Check that the environment variables this class needs are present.

        Raises:
            ValueError: if AWS_REGION_NAME or LITELLM_LICENSE is not set.
        """
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

        ## CHECK IF LICENSE IN ENV ## - premium feature
        if os.getenv("LITELLM_LICENSE", None) is None:
            raise ValueError(
                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
            )

    def load_aws_kms(self, use_aws_kms: Optional[bool]):
        """
        Create the boto3 KMS client.

        Args:
            use_aws_kms: when None or False nothing is created and None is returned.

        Returns:
            A boto3 "kms" client for the region in AWS_REGION_NAME, or None.

        Raises:
            ValueError: if the environment is not valid (see `validate_environment`).
        """
        if use_aws_kms is None or use_aws_kms is False:
            return None

        # Validate through this instance's own check (not a module-level
        # function), and do it before importing boto3 so a misconfigured
        # environment is reported first.
        self.validate_environment()

        import boto3

        # Create a KMS client
        return boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    def decrypt_value(self, secret_name: str) -> Any:
        """
        Decrypt the KMS-encrypted value held in the env var `secret_name`.

        Args:
            secret_name: name of the environment variable holding the
                (optionally `aws_kms/`-prefixed) base64 ciphertext.

        Returns:
            The decrypted plaintext, stripped of surrounding whitespace. If the
            plaintext is the literal `True` or `False`, the bool is returned.

        Raises:
            ValueError: if no KMS client was created.
            Exception: if the environment variable is not set.
        """
        if self.kms_client is None:
            raise ValueError("kms_client is None")

        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )

        # Strip only the leading marker; never touch the ciphertext itself.
        prefix = "aws_kms/"
        if encrypted_value.startswith(prefix):
            encrypted_value = encrypted_value[len(prefix) :]

        # Decode the base64 encoded ciphertext
        ciphertext_blob = base64.b64decode(encrypted_value)

        # Perform the decryption
        response = self.kms_client.decrypt(CiphertextBlob=ciphertext_blob)

        # Extract and decode the plaintext
        secret = response["Plaintext"].decode("utf-8").strip()

        # Best-effort: only a literal bool is converted; anything that does
        # not parse as a Python literal is returned as the plain string.
        try:
            parsed = ast.literal_eval(secret)
        except Exception:
            return secret
        if isinstance(parsed, bool):
            return parsed
        return secret
"""
- look for all values in the env with `aws_kms/<encrypted_key>`
- decrypt keys
- return the decrypted values keyed by env var name, so the caller can rewrite each env var with its decrypted key. Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
"""
def decrypt_env_var() -> dict[str, Any]:
    """
    Decrypt every environment variable whose value starts with `aws_kms/`.

    The KMS service object is created first, so its environment validation
    runs even when no variable needs decrypting.

    Returns:
        Mapping of env var name -> decrypted value. The environment itself is
        not modified here.
    """
    kms_service = AWSKeyManagementService_V2()

    # collect the decrypted value for each `aws_kms/`-prefixed entry
    return {
        name: kms_service.decrypt_value(secret_name=name)
        for name, raw in os.environ.items()
        if raw is not None and isinstance(raw, str) and raw.startswith("aws_kms/")
    }

59
tests/test_entrypoint.py Normal file
View file

@ -0,0 +1,59 @@
# What is this?
## Unit tests for 'entrypoint.sh'
import pytest
import sys
import os
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
import subprocess
@pytest.mark.skip(reason="local test")
def test_decrypt_and_reset_env(monkeypatch):
    """
    Decrypt an `aws_kms/`-prefixed DATABASE_URL and write it back to the env.

    Uses `monkeypatch` so the environment changes are undone after the test.
    Needs real AWS credentials / KMS access, hence skipped by default.
    """
    monkeypatch.setenv(
        "DATABASE_URL",
        "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La",
    )

    # The secret manager exposes `decrypt_env_var`, which returns the decrypted
    # values; writing them back to the environment is the caller's job.
    from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var

    for key, value in decrypt_env_var().items():
        monkeypatch.setenv(key, str(value))

    assert os.environ["DATABASE_URL"] is not None
    assert isinstance(os.environ["DATABASE_URL"], str)
    assert not os.environ["DATABASE_URL"].startswith("aws_kms/")

    # Deliberately not printing the value: it is a decrypted secret.
@pytest.mark.skip(reason="local test")
def test_entrypoint_decrypt_and_reset(monkeypatch):
    """
    Run `entrypoint.sh` end to end with an `aws_kms/`-prefixed DATABASE_URL.

    The child process inherits this process's environment, so the variables
    are set through `monkeypatch` and restored after the test. Needs real AWS
    credentials and a reachable database, hence skipped by default.
    """
    monkeypatch.setenv(
        "DATABASE_URL",
        "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La",
    )
    # The migration script only decrypts when USE_AWS_KMS is exactly "True".
    monkeypatch.setenv("USE_AWS_KMS", "True")

    command = "./entrypoint.sh"
    directory = ".."  # Relative to the current directory

    # Run the command using subprocess
    result = subprocess.run(
        command, shell=True, cwd=directory, capture_output=True, text=True
    )

    # Print the output for debugging purposes
    print("STDOUT:", result.stdout)
    print("STDERR:", result.stderr)

    # Assert the script ran successfully
    assert result.returncode == 0, "The shell script did not execute successfully"
    # NOTE(review): confirm which component prints "DECRYPTS VALUE"; adjust or
    # drop this assertion if that marker is no longer emitted.
    assert (
        "DECRYPTS VALUE" in result.stdout
    ), "Expected output not found in script output"
    assert (
        "Database push successful!" in result.stdout
    ), "Expected output not found in script output"