forked from phoenix/litellm-mirror
fix(initial-commit): decrypts aws keys in entrypoint.sh
This commit is contained in:
parent
3dc578555c
commit
adcd55fca0
3 changed files with 186 additions and 22 deletions
|
@ -1,5 +1,20 @@
|
|||
#!/bin/sh
|
||||
|
||||
echo "Current working directory: $(pwd)"
|
||||
|
||||
|
||||
# Check if SET_AWS_KMS in env
|
||||
if [ -n "$SET_AWS_KMS" ]; then
|
||||
# Call Python function to decrypt and reset environment variables
|
||||
env_vars=$(python -c 'from litellm.proxy.secret_managers.aws_secret_manager import decrypt_and_reset_env_var; env_vars = decrypt_and_reset_env_var();')
|
||||
echo "Received env_vars: ${env_vars}"
|
||||
# Export decrypted environment variables to the current Bash environment
|
||||
while IFS='=' read -r name value; do
|
||||
export "$name=$value"
|
||||
done <<< "$env_vars"
|
||||
fi
|
||||
|
||||
echo "DATABASE_URL post kms: $($DATABASE_URL)"
|
||||
# Check if DATABASE_URL is not set
|
||||
if [ -z "$DATABASE_URL" ]; then
|
||||
# Check if all required variables are provided
|
||||
|
@ -13,36 +28,38 @@ if [ -z "$DATABASE_URL" ]; then
|
|||
fi
|
||||
fi
|
||||
|
||||
echo "DATABASE_URL: $($DATABASE_URL)"
|
||||
|
||||
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
|
||||
if [ -z "$DIRECT_URL" ]; then
|
||||
export DIRECT_URL=$DATABASE_URL
|
||||
fi
|
||||
|
||||
# Apply migrations
|
||||
retry_count=0
|
||||
max_retries=3
|
||||
exit_code=1
|
||||
# # Apply migrations
|
||||
# retry_count=0
|
||||
# max_retries=3
|
||||
# exit_code=1
|
||||
|
||||
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
||||
do
|
||||
retry_count=$((retry_count+1))
|
||||
echo "Attempt $retry_count..."
|
||||
# until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
||||
# do
|
||||
# retry_count=$((retry_count+1))
|
||||
# echo "Attempt $retry_count..."
|
||||
|
||||
# Run the Prisma db push command
|
||||
prisma db push --accept-data-loss
|
||||
# # Run the Prisma db push command
|
||||
# prisma db push --accept-data-loss
|
||||
|
||||
exit_code=$?
|
||||
# exit_code=$?
|
||||
|
||||
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
||||
echo "Retrying in 10 seconds..."
|
||||
sleep 10
|
||||
fi
|
||||
done
|
||||
# if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
||||
# echo "Retrying in 10 seconds..."
|
||||
# sleep 10
|
||||
# fi
|
||||
# done
|
||||
|
||||
if [ $exit_code -ne 0 ]; then
|
||||
echo "Unable to push database changes after $max_retries retries."
|
||||
exit 1
|
||||
fi
|
||||
# if [ $exit_code -ne 0 ]; then
|
||||
# echo "Unable to push database changes after $max_retries retries."
|
||||
# exit 1
|
||||
# fi
|
||||
|
||||
echo "Database push successful!"
|
||||
|
||||
|
|
|
@ -8,9 +8,12 @@ Requires:
|
|||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import litellm
|
||||
import ast
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
|
@ -57,3 +60,90 @@ def load_aws_kms(use_aws_kms: Optional[bool]):
|
|||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class AWSKeyManagementService:
|
||||
"""
|
||||
V2 Clean Class for decrypting keys from AWS KeyManagementService
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.validate_environment()
|
||||
self.kms_client = self.load_aws_kms(use_aws_kms=True)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
):
|
||||
if "AWS_REGION_NAME" not in os.environ:
|
||||
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||
|
||||
def load_aws_kms(self, use_aws_kms: Optional[bool]):
|
||||
if use_aws_kms is None or use_aws_kms is False:
|
||||
return
|
||||
try:
|
||||
import boto3
|
||||
|
||||
validate_environment()
|
||||
|
||||
# Create a Secrets Manager client
|
||||
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
|
||||
|
||||
litellm.secret_manager_client = kms_client
|
||||
litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
||||
return kms_client
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def decrypt_value(self, secret_name: str) -> Any:
|
||||
if self.kms_client is None:
|
||||
raise ValueError("kms_client is None")
|
||||
encrypted_value = os.getenv(secret_name, None)
|
||||
if encrypted_value is None:
|
||||
raise Exception(
|
||||
"AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
|
||||
)
|
||||
if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"):
|
||||
encrypted_value = encrypted_value.replace("aws_kms/", "")
|
||||
|
||||
# Decode the base64 encoded ciphertext
|
||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||
|
||||
# Set up the parameters for the decrypt call
|
||||
params = {"CiphertextBlob": ciphertext_blob}
|
||||
# Perform the decryption
|
||||
response = self.kms_client.decrypt(**params)
|
||||
|
||||
# Extract and decode the plaintext
|
||||
plaintext = response["Plaintext"]
|
||||
secret = plaintext.decode("utf-8")
|
||||
if isinstance(secret, str):
|
||||
secret = secret.strip()
|
||||
try:
|
||||
secret_value_as_bool = ast.literal_eval(secret)
|
||||
if isinstance(secret_value_as_bool, bool):
|
||||
return secret_value_as_bool
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return secret
|
||||
|
||||
|
||||
"""
|
||||
- look for all values in the env with `aws_kms/<hashed_key>`
|
||||
- decrypt keys
|
||||
- rewrite env var with decrypted key
|
||||
"""
|
||||
|
||||
|
||||
def decrypt_and_reset_env_var() -> dict:
|
||||
# setup client class
|
||||
aws_kms = AWSKeyManagementService()
|
||||
# iterate through env - for `aws_kms/`
|
||||
new_values = {}
|
||||
for k, v in os.environ.items():
|
||||
if v is not None and isinstance(v, str) and v.startswith("aws_kms/"):
|
||||
decrypted_value = aws_kms.decrypt_value(secret_name=k)
|
||||
# reset env var
|
||||
new_values[k] = decrypted_value
|
||||
|
||||
return new_values
|
||||
|
|
57
tests/test_entrypoint.py
Normal file
57
tests/test_entrypoint.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
# What is this?
|
||||
## Unit tests for 'entrypoint.sh'
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
import subprocess
|
||||
|
||||
|
||||
def test_decrypt_and_reset_env():
|
||||
os.environ["DATABASE_URL"] = (
|
||||
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||
)
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
decrypt_and_reset_env_var,
|
||||
)
|
||||
|
||||
decrypt_and_reset_env_var()
|
||||
|
||||
assert os.environ["DATABASE_URL"] is not None
|
||||
assert isinstance(os.environ["DATABASE_URL"], str)
|
||||
assert not os.environ["DATABASE_URL"].startswith("aws_kms/")
|
||||
|
||||
print("DATABASE_URL={}".format(os.environ["DATABASE_URL"]))
|
||||
|
||||
|
||||
def test_entrypoint_decrypt_and_reset():
|
||||
os.environ["DATABASE_URL"] = (
|
||||
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||
)
|
||||
command = "./entrypoint.sh"
|
||||
directory = ".." # Relative to the current directory
|
||||
|
||||
# Run the command using subprocess
|
||||
result = subprocess.run(
|
||||
command, shell=True, cwd=directory, capture_output=True, text=True
|
||||
)
|
||||
|
||||
# Print the output for debugging purposes
|
||||
print("STDOUT:", result.stdout)
|
||||
print("STDERR:", result.stderr)
|
||||
|
||||
# Assert the script ran successfully
|
||||
assert result.returncode == 0, "The shell script did not execute successfully"
|
||||
assert (
|
||||
"DECRYPTS VALUE" in result.stdout
|
||||
), "Expected output not found in script output"
|
||||
assert (
|
||||
"Database push successful!" in result.stdout
|
||||
), "Expected output not found in script output"
|
||||
|
||||
assert False
|
Loading…
Add table
Add a link
Reference in a new issue