forked from phoenix/litellm-mirror
fix(proxy_cli.py): run aws kms decrypt before starting proxy server
This commit is contained in:
parent
adcd55fca0
commit
b84d335624
3 changed files with 45 additions and 44 deletions
|
@ -1,20 +1,5 @@
|
||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
echo "Current working directory: $(pwd)"
|
|
||||||
|
|
||||||
|
|
||||||
# Check if SET_AWS_KMS in env
|
|
||||||
if [ -n "$SET_AWS_KMS" ]; then
|
|
||||||
# Call Python function to decrypt and reset environment variables
|
|
||||||
env_vars=$(python -c 'from litellm.proxy.secret_managers.aws_secret_manager import decrypt_and_reset_env_var; env_vars = decrypt_and_reset_env_var();')
|
|
||||||
echo "Received env_vars: ${env_vars}"
|
|
||||||
# Export decrypted environment variables to the current Bash environment
|
|
||||||
while IFS='=' read -r name value; do
|
|
||||||
export "$name=$value"
|
|
||||||
done <<< "$env_vars"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "DATABASE_URL post kms: $($DATABASE_URL)"
|
|
||||||
# Check if DATABASE_URL is not set
|
# Check if DATABASE_URL is not set
|
||||||
if [ -z "$DATABASE_URL" ]; then
|
if [ -z "$DATABASE_URL" ]; then
|
||||||
# Check if all required variables are provided
|
# Check if all required variables are provided
|
||||||
|
@ -28,38 +13,35 @@ if [ -z "$DATABASE_URL" ]; then
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "DATABASE_URL: $($DATABASE_URL)"
|
|
||||||
|
|
||||||
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
|
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
|
||||||
if [ -z "$DIRECT_URL" ]; then
|
if [ -z "$DIRECT_URL" ]; then
|
||||||
export DIRECT_URL=$DATABASE_URL
|
export DIRECT_URL=$DATABASE_URL
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# # Apply migrations
|
# Apply migrations
|
||||||
# retry_count=0
|
retry_count=0
|
||||||
# max_retries=3
|
max_retries=3
|
||||||
# exit_code=1
|
exit_code=1
|
||||||
|
|
||||||
# until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
||||||
# do
|
do
|
||||||
# retry_count=$((retry_count+1))
|
retry_count=$((retry_count+1))
|
||||||
# echo "Attempt $retry_count..."
|
echo "Attempt $retry_count..."
|
||||||
|
|
||||||
# # Run the Prisma db push command
|
# Run the Prisma db push command
|
||||||
# prisma db push --accept-data-loss
|
prisma db push --accept-data-loss
|
||||||
|
|
||||||
# exit_code=$?
|
exit_code=$?
|
||||||
|
|
||||||
# if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
||||||
# echo "Retrying in 10 seconds..."
|
echo "Retrying in 10 seconds..."
|
||||||
# sleep 10
|
sleep 10
|
||||||
# fi
|
fi
|
||||||
# done
|
done
|
||||||
|
|
||||||
# if [ $exit_code -ne 0 ]; then
|
if [ $exit_code -ne 0 ]; then
|
||||||
# echo "Unable to push database changes after $max_retries retries."
|
echo "Unable to push database changes after $max_retries retries."
|
||||||
# exit 1
|
exit 1
|
||||||
# fi
|
fi
|
||||||
|
|
||||||
echo "Database push successful!"
|
echo "Database push successful!"
|
||||||
|
|
||||||
|
|
|
@ -442,6 +442,20 @@ def run_server(
|
||||||
|
|
||||||
db_connection_pool_limit = 100
|
db_connection_pool_limit = 100
|
||||||
db_connection_timeout = 60
|
db_connection_timeout = 60
|
||||||
|
### DECRYPT ENV VAR ###
|
||||||
|
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||||
|
|
||||||
|
if (
|
||||||
|
os.getenv("USE_AWS_KMS", None) is not None
|
||||||
|
and os.getenv("USE_AWS_KMS") == "True"
|
||||||
|
):
|
||||||
|
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
|
||||||
|
new_env_var = decrypt_env_var()
|
||||||
|
|
||||||
|
for k, v in new_env_var.items():
|
||||||
|
os.environ[k] = v
|
||||||
|
|
||||||
if config is not None:
|
if config is not None:
|
||||||
"""
|
"""
|
||||||
Allow user to pass in db url via config
|
Allow user to pass in db url via config
|
||||||
|
@ -459,6 +473,7 @@ def run_server(
|
||||||
|
|
||||||
proxy_config = ProxyConfig()
|
proxy_config = ProxyConfig()
|
||||||
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
|
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
|
||||||
|
|
||||||
### LITELLM SETTINGS ###
|
### LITELLM SETTINGS ###
|
||||||
litellm_settings = _config.get("litellm_settings", None)
|
litellm_settings = _config.get("litellm_settings", None)
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -62,7 +62,7 @@ def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
class AWSKeyManagementService:
|
class AWSKeyManagementService_V2:
|
||||||
"""
|
"""
|
||||||
V2 Clean Class for decrypting keys from AWS KeyManagementService
|
V2 Clean Class for decrypting keys from AWS KeyManagementService
|
||||||
"""
|
"""
|
||||||
|
@ -77,6 +77,12 @@ class AWSKeyManagementService:
|
||||||
if "AWS_REGION_NAME" not in os.environ:
|
if "AWS_REGION_NAME" not in os.environ:
|
||||||
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||||
|
|
||||||
|
## CHECK IF LICENSE IN ENV ## - premium feature
|
||||||
|
if os.getenv("LITELLM_LICENSE", None) is None:
|
||||||
|
raise ValueError(
|
||||||
|
"AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
|
||||||
|
)
|
||||||
|
|
||||||
def load_aws_kms(self, use_aws_kms: Optional[bool]):
|
def load_aws_kms(self, use_aws_kms: Optional[bool]):
|
||||||
if use_aws_kms is None or use_aws_kms is False:
|
if use_aws_kms is None or use_aws_kms is False:
|
||||||
return
|
return
|
||||||
|
@ -88,8 +94,6 @@ class AWSKeyManagementService:
|
||||||
# Create a Secrets Manager client
|
# Create a Secrets Manager client
|
||||||
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
|
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
|
||||||
|
|
||||||
litellm.secret_manager_client = kms_client
|
|
||||||
litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
|
||||||
return kms_client
|
return kms_client
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -131,13 +135,13 @@ class AWSKeyManagementService:
|
||||||
"""
|
"""
|
||||||
- look for all values in the env with `aws_kms/<hashed_key>`
|
- look for all values in the env with `aws_kms/<hashed_key>`
|
||||||
- decrypt keys
|
- decrypt keys
|
||||||
- rewrite env var with decrypted key
|
- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def decrypt_and_reset_env_var() -> dict:
|
def decrypt_env_var() -> dict[str, Any]:
|
||||||
# setup client class
|
# setup client class
|
||||||
aws_kms = AWSKeyManagementService()
|
aws_kms = AWSKeyManagementService_V2()
|
||||||
# iterate through env - for `aws_kms/`
|
# iterate through env - for `aws_kms/`
|
||||||
new_values = {}
|
new_values = {}
|
||||||
for k, v in os.environ.items():
|
for k, v in os.environ.items():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue