fix(proxy_cli.py): run aws kms decrypt before starting proxy server

This commit is contained in:
Krrish Dholakia 2024-06-28 16:03:56 -07:00
parent adcd55fca0
commit b84d335624
3 changed files with 45 additions and 44 deletions

View file

@ -1,20 +1,5 @@
#!/bin/sh
echo "Current working directory: $(pwd)"
# Check if SET_AWS_KMS in env
if [ -n "$SET_AWS_KMS" ]; then
# Call Python function to decrypt and reset environment variables
env_vars=$(python -c 'from litellm.proxy.secret_managers.aws_secret_manager import decrypt_and_reset_env_var; env_vars = decrypt_and_reset_env_var();')
echo "Received env_vars: ${env_vars}"
# Export decrypted environment variables to the current Bash environment
while IFS='=' read -r name value; do
export "$name=$value"
done <<< "$env_vars"
fi
echo "DATABASE_URL post kms: $($DATABASE_URL)"
# Check if DATABASE_URL is not set
if [ -z "$DATABASE_URL" ]; then
# Check if all required variables are provided
@ -28,38 +13,35 @@ if [ -z "$DATABASE_URL" ]; then
fi
fi
echo "DATABASE_URL: $($DATABASE_URL)"
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
if [ -z "$DIRECT_URL" ]; then
export DIRECT_URL=$DATABASE_URL
fi
# # Apply migrations
# retry_count=0
# max_retries=3
# exit_code=1
# Apply migrations
retry_count=0
max_retries=3
exit_code=1
# until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
# do
# retry_count=$((retry_count+1))
# echo "Attempt $retry_count..."
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
do
retry_count=$((retry_count+1))
echo "Attempt $retry_count..."
# # Run the Prisma db push command
# prisma db push --accept-data-loss
# Run the Prisma db push command
prisma db push --accept-data-loss
# exit_code=$?
exit_code=$?
# if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
# echo "Retrying in 10 seconds..."
# sleep 10
# fi
# done
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
echo "Retrying in 10 seconds..."
sleep 10
fi
done
# if [ $exit_code -ne 0 ]; then
# echo "Unable to push database changes after $max_retries retries."
# exit 1
# fi
if [ $exit_code -ne 0 ]; then
echo "Unable to push database changes after $max_retries retries."
exit 1
fi
echo "Database push successful!"

View file

@ -442,6 +442,20 @@ def run_server(
db_connection_pool_limit = 100
db_connection_timeout = 60
### DECRYPT ENV VAR ###
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
if (
os.getenv("USE_AWS_KMS", None) is not None
and os.getenv("USE_AWS_KMS") == "True"
):
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
new_env_var = decrypt_env_var()
for k, v in new_env_var.items():
os.environ[k] = v
if config is not None:
"""
Allow user to pass in db url via config
@ -459,6 +473,7 @@ def run_server(
proxy_config = ProxyConfig()
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
### LITELLM SETTINGS ###
litellm_settings = _config.get("litellm_settings", None)
if (

View file

@ -62,7 +62,7 @@ def load_aws_kms(use_aws_kms: Optional[bool]):
raise e
class AWSKeyManagementService:
class AWSKeyManagementService_V2:
"""
V2 Clean Class for decrypting keys from AWS KeyManagementService
"""
@ -77,6 +77,12 @@ class AWSKeyManagementService:
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
## CHECK IF LICENSE IN ENV ## - premium feature
if os.getenv("LITELLM_LICENSE", None) is None:
raise ValueError(
"AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
)
def load_aws_kms(self, use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
@ -88,8 +94,6 @@ class AWSKeyManagementService:
# Create a Secrets Manager client
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
litellm.secret_manager_client = kms_client
litellm._key_management_system = KeyManagementSystem.AWS_KMS
return kms_client
except Exception as e:
raise e
@ -131,13 +135,13 @@ class AWSKeyManagementService:
"""
- look for all values in the env with `aws_kms/<hashed_key>`
- decrypt keys
- rewrite env var with decrypted key
- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
"""
def decrypt_and_reset_env_var() -> dict:
def decrypt_env_var() -> dict[str, Any]:
# setup client class
aws_kms = AWSKeyManagementService()
aws_kms = AWSKeyManagementService_V2()
# iterate through env - for `aws_kms/`
new_values = {}
for k, v in os.environ.items():