From adcd55fca0c8ecb0854a818df6839023090691e9 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 26 Jun 2024 22:33:26 -0700 Subject: [PATCH 1/3] fix(initial-commit): decrypts aws keys in entrypoint.sh --- entrypoint.sh | 57 +++++++---- .../secret_managers/aws_secret_manager.py | 94 ++++++++++++++++++- tests/test_entrypoint.py | 57 +++++++++++ 3 files changed, 186 insertions(+), 22 deletions(-) create mode 100644 tests/test_entrypoint.py diff --git a/entrypoint.sh b/entrypoint.sh index 80adf8d077..a76f126a30 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,5 +1,20 @@ #!/bin/sh +echo "Current working directory: $(pwd)" + + +# Check if SET_AWS_KMS in env +if [ -n "$SET_AWS_KMS" ]; then + # Call Python function to decrypt and reset environment variables + env_vars=$(python -c 'from litellm.proxy.secret_managers.aws_secret_manager import decrypt_and_reset_env_var; env_vars = decrypt_and_reset_env_var();') + echo "Received env_vars: ${env_vars}" + # Export decrypted environment variables to the current Bash environment + while IFS='=' read -r name value; do + export "$name=$value" + done <<< "$env_vars" +fi + +echo "DATABASE_URL post kms: $($DATABASE_URL)" # Check if DATABASE_URL is not set if [ -z "$DATABASE_URL" ]; then # Check if all required variables are provided @@ -13,36 +28,38 @@ if [ -z "$DATABASE_URL" ]; then fi fi +echo "DATABASE_URL: $($DATABASE_URL)" + # Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations if [ -z "$DIRECT_URL" ]; then export DIRECT_URL=$DATABASE_URL fi -# Apply migrations -retry_count=0 -max_retries=3 -exit_code=1 +# # Apply migrations +# retry_count=0 +# max_retries=3 +# exit_code=1 -until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] -do - retry_count=$((retry_count+1)) - echo "Attempt $retry_count..." +# until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] +# do +# retry_count=$((retry_count+1)) +# echo "Attempt $retry_count..." - # Run the Prisma db push command - prisma db push --accept-data-loss +# # Run the Prisma db push command +# prisma db push --accept-data-loss - exit_code=$? +# exit_code=$? - if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then - echo "Retrying in 10 seconds..." - sleep 10 - fi -done +# if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then +# echo "Retrying in 10 seconds..." +# sleep 10 +# fi +# done -if [ $exit_code -ne 0 ]; then - echo "Unable to push database changes after $max_retries retries." - exit 1 -fi +# if [ $exit_code -ne 0 ]; then +# echo "Unable to push database changes after $max_retries retries." +# exit 1 +# fi echo "Database push successful!" diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py index 8dd6772cf7..49c79b68b6 100644 --- a/litellm/proxy/secret_managers/aws_secret_manager.py +++ b/litellm/proxy/secret_managers/aws_secret_manager.py @@ -8,9 +8,12 @@ Requires: * `pip install boto3>=1.28.57` """ -import litellm +import ast +import base64 import os -from typing import Optional +from typing import Any, Optional + +import litellm from litellm.proxy._types import KeyManagementSystem @@ -57,3 +60,90 @@ def load_aws_kms(use_aws_kms: Optional[bool]): except Exception as e: raise e + + +class AWSKeyManagementService: + """ + V2 Clean Class for decrypting keys from AWS KeyManagementService + """ + + def __init__(self) -> None: + self.validate_environment() + self.kms_client = self.load_aws_kms(use_aws_kms=True) + + def validate_environment( + self, + ): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + def load_aws_kms(self, use_aws_kms: Optional[bool]): + if use_aws_kms is None or use_aws_kms is False: + return + try: + import boto3 + + validate_environment() + + # Create a Secrets Manager client + kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME")) + + litellm.secret_manager_client = kms_client + litellm._key_management_system = KeyManagementSystem.AWS_KMS + return kms_client + except Exception as e: + raise e + + def decrypt_value(self, secret_name: str) -> Any: + if self.kms_client is None: + raise ValueError("kms_client is None") + encrypted_value = os.getenv(secret_name, None) + if encrypted_value is None: + raise Exception( + "AWS KMS - Encrypted Value of Key={} is None".format(secret_name) + ) + if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"): + encrypted_value = encrypted_value.replace("aws_kms/", "") + + # Decode the base64 encoded ciphertext + ciphertext_blob = base64.b64decode(encrypted_value) + + # Set up the parameters for the decrypt call + params = {"CiphertextBlob": ciphertext_blob} + # Perform the decryption + response = self.kms_client.decrypt(**params) + + # Extract and decode the plaintext + plaintext = response["Plaintext"] + secret = plaintext.decode("utf-8") + if isinstance(secret, str): + secret = secret.strip() + try: + secret_value_as_bool = ast.literal_eval(secret) + if isinstance(secret_value_as_bool, bool): + return secret_value_as_bool + except Exception: + pass + + return secret + + +""" +- look for all values in the env with `aws_kms/` +- decrypt keys +- rewrite env var with decrypted key +""" + + +def decrypt_and_reset_env_var() -> dict: + # setup client class + aws_kms = AWSKeyManagementService() + # iterate through env - for `aws_kms/` + new_values = {} + for k, v in os.environ.items(): + if v is not None and isinstance(v, str) and v.startswith("aws_kms/"): + decrypted_value = aws_kms.decrypt_value(secret_name=k) + # reset env var + new_values[k] = decrypted_value + + return new_values diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py new file mode 100644 index 0000000000..cbf14c6ead --- /dev/null +++ b/tests/test_entrypoint.py @@ -0,0 +1,57 @@ +# What is this? +## Unit tests for 'entrypoint.sh' + +import pytest +import sys +import os + +sys.path.insert( + 0, os.path.abspath("../") +) # Adds the parent directory to the system path +import litellm +import subprocess + + +def test_decrypt_and_reset_env(): + os.environ["DATABASE_URL"] = ( + "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" + ) + from litellm.proxy.secret_managers.aws_secret_manager import ( + decrypt_and_reset_env_var, + ) + + decrypt_and_reset_env_var() + + assert os.environ["DATABASE_URL"] is not None + assert isinstance(os.environ["DATABASE_URL"], str) + assert not os.environ["DATABASE_URL"].startswith("aws_kms/") + + print("DATABASE_URL={}".format(os.environ["DATABASE_URL"])) + + +def test_entrypoint_decrypt_and_reset(): + os.environ["DATABASE_URL"] = ( + "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" + ) + command = "./entrypoint.sh" + directory = ".." # Relative to the current directory + + # Run the command using subprocess + result = subprocess.run( + command, shell=True, cwd=directory, capture_output=True, text=True + ) + + # Print the output for debugging purposes + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + + # Assert the script ran successfully + assert result.returncode == 0, "The shell script did not execute successfully" + assert ( + "DECRYPTS VALUE" in result.stdout + ), "Expected output not found in script output" + assert ( + "Database push successful!" in result.stdout + ), "Expected output not found in script output" + + assert False From b84d335624dc9265b9154f296aa165d6e6393676 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 28 Jun 2024 16:03:56 -0700 Subject: [PATCH 2/3] fix(proxy_cli.py): run aws kms decrypt before starting proxy server --- entrypoint.sh | 58 +++++++------------ litellm/proxy/proxy_cli.py | 15 +++++ .../secret_managers/aws_secret_manager.py | 16 +++-- 3 files changed, 45 insertions(+), 44 deletions(-) diff --git a/entrypoint.sh b/entrypoint.sh index a76f126a30..6e47dde12c 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,20 +1,5 @@ #!/bin/sh -echo "Current working directory: $(pwd)" - - -# Check if SET_AWS_KMS in env -if [ -n "$SET_AWS_KMS" ]; then - # Call Python function to decrypt and reset environment variables - env_vars=$(python -c 'from litellm.proxy.secret_managers.aws_secret_manager import decrypt_and_reset_env_var; env_vars = decrypt_and_reset_env_var();') - echo "Received env_vars: ${env_vars}" - # Export decrypted environment variables to the current Bash environment - while IFS='=' read -r name value; do - export "$name=$value" - done <<< "$env_vars" -fi - -echo "DATABASE_URL post kms: $($DATABASE_URL)" # Check if DATABASE_URL is not set if [ -z "$DATABASE_URL" ]; then # Check if all required variables are provided @@ -28,38 +13,35 @@ if [ -z "$DATABASE_URL" ]; then fi fi -echo "DATABASE_URL: $($DATABASE_URL)" - # Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations if [ -z "$DIRECT_URL" ]; then export DIRECT_URL=$DATABASE_URL fi -# # Apply migrations -# retry_count=0 -# max_retries=3 -# exit_code=1 +# Apply migrations +retry_count=0 +max_retries=3 +exit_code=1 -# until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] -# do -# retry_count=$((retry_count+1)) -# echo "Attempt $retry_count..." +until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] +do + retry_count=$((retry_count+1)) + echo "Attempt $retry_count..." -# # Run the Prisma db push command -# prisma db push --accept-data-loss + # Run the Prisma db push command + prisma db push --accept-data-loss -# exit_code=$? + exit_code=$? -# if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then -# echo "Retrying in 10 seconds..." -# sleep 10 -# fi -# done + if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then + echo "Retrying in 10 seconds..." + sleep 10 + fi +done -# if [ $exit_code -ne 0 ]; then -# echo "Unable to push database changes after $max_retries retries." -# exit 1 -# fi +if [ $exit_code -ne 0 ]; then + echo "Unable to push database changes after $max_retries retries." + exit 1 +fi echo "Database push successful!" - diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 6e6d1f4a9e..e987046428 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -442,6 +442,20 @@ def run_server( db_connection_pool_limit = 100 db_connection_timeout = 60 + ### DECRYPT ENV VAR ### + + from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + + if ( + os.getenv("USE_AWS_KMS", None) is not None + and os.getenv("USE_AWS_KMS") == "True" + ): + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + if config is not None: """ Allow user to pass in db url via config @@ -459,6 +473,7 @@ def run_server( proxy_config = ProxyConfig() _config = asyncio.run(proxy_config.get_config(config_file_path=config)) + ### LITELLM SETTINGS ### litellm_settings = _config.get("litellm_settings", None) if ( diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py index 49c79b68b6..9e5d5befe8 100644 --- a/litellm/proxy/secret_managers/aws_secret_manager.py +++ b/litellm/proxy/secret_managers/aws_secret_manager.py @@ -62,7 +62,7 @@ def load_aws_kms(use_aws_kms: Optional[bool]): raise e -class AWSKeyManagementService: +class AWSKeyManagementService_V2: """ V2 Clean Class for decrypting keys from AWS KeyManagementService """ @@ -77,6 +77,12 @@ class AWSKeyManagementService: if "AWS_REGION_NAME" not in os.environ: raise ValueError("Missing required environment variable - AWS_REGION_NAME") + ## CHECK IF LICENSE IN ENV ## - premium feature + if os.getenv("LITELLM_LICENSE", None) is None: + raise ValueError( + "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment." + ) + def load_aws_kms(self, use_aws_kms: Optional[bool]): if use_aws_kms is None or use_aws_kms is False: return @@ -88,8 +94,6 @@ class AWSKeyManagementService: # Create a Secrets Manager client kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME")) - litellm.secret_manager_client = kms_client - litellm._key_management_system = KeyManagementSystem.AWS_KMS return kms_client except Exception as e: raise e @@ -131,13 +135,13 @@ class AWSKeyManagementService: """ - look for all values in the env with `aws_kms/` - decrypt keys -- rewrite env var with decrypted key +- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist. """ -def decrypt_and_reset_env_var() -> dict: +def decrypt_env_var() -> dict[str, Any]: # setup client class - aws_kms = AWSKeyManagementService() + aws_kms = AWSKeyManagementService_V2() # iterate through env - for `aws_kms/` new_values = {} for k, v in os.environ.items(): From b78dd6416ac05392bf7a5c8823ce0fe2f5278947 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 28 Jun 2024 16:31:37 -0700 Subject: [PATCH 3/3] fix(prisma_migration.py): support decrypting variables in a python script --- entrypoint.sh | 52 ++++------------------- litellm/proxy/prisma_migration.py | 68 +++++++++++++++++++++++++++++++ tests/test_entrypoint.py | 2 + 3 files changed, 79 insertions(+), 43 deletions(-) create mode 100644 litellm/proxy/prisma_migration.py diff --git a/entrypoint.sh b/entrypoint.sh index 6e47dde12c..a028e54262 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,47 +1,13 @@ -#!/bin/sh +#!/bin/bash +echo $(pwd) -# Check if DATABASE_URL is not set -if [ -z "$DATABASE_URL" ]; then - # Check if all required variables are provided - if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then - # Construct DATABASE_URL from the provided variables - DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}" - export DATABASE_URL - else - echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." - exit 1 - fi -fi +# Run the Python migration script +python3 litellm/proxy/prisma_migration.py -# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations -if [ -z "$DIRECT_URL" ]; then - export DIRECT_URL=$DATABASE_URL -fi - -# Apply migrations -retry_count=0 -max_retries=3 -exit_code=1 - -until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] -do - retry_count=$((retry_count+1)) - echo "Attempt $retry_count..." - - # Run the Prisma db push command - prisma db push --accept-data-loss - - exit_code=$? - - if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then - echo "Retrying in 10 seconds..." - sleep 10 - fi -done - -if [ $exit_code -ne 0 ]; then - echo "Unable to push database changes after $max_retries retries." +# Check if the Python script executed successfully +if [ $? -eq 0 ]; then + echo "Migration script ran successfully!" +else + echo "Migration script failed!" exit 1 fi - -echo "Database push successful!" diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py new file mode 100644 index 0000000000..6ee09c22b6 --- /dev/null +++ b/litellm/proxy/prisma_migration.py @@ -0,0 +1,68 @@ +# What is this? +## Script to apply initial prisma migration on Docker setup + +import os +import subprocess +import sys +import time + +sys.path.insert( + 0, os.path.abspath("./") +) # Adds the parent directory to the system path +from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + +if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + +# Check if DATABASE_URL is not set +database_url = os.getenv("DATABASE_URL") +if not database_url: + # Check if all required variables are provided + database_host = os.getenv("DATABASE_HOST") + database_username = os.getenv("DATABASE_USERNAME") + database_password = os.getenv("DATABASE_PASSWORD") + database_name = os.getenv("DATABASE_NAME") + + if database_host and database_username and database_password and database_name: + # Construct DATABASE_URL from the provided variables + database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}" + os.environ["DATABASE_URL"] = database_url + else: + print( # noqa + "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." # noqa + ) + exit(1) + +# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations +direct_url = os.getenv("DIRECT_URL") +if not direct_url: + os.environ["DIRECT_URL"] = database_url + +# Apply migrations +retry_count = 0 +max_retries = 3 +exit_code = 1 + +while retry_count < max_retries and exit_code != 0: + retry_count += 1 + print(f"Attempt {retry_count}...") # noqa + + # Run the Prisma db push command + result = subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"], capture_output=True + ) + exit_code = result.returncode + + if exit_code != 0 and retry_count < max_retries: + print("Retrying in 10 seconds...") # noqa + time.sleep(10) + +if exit_code != 0: + print(f"Unable to push database changes after {max_retries} retries.") # noqa + exit(1) + +print("Database push successful!") # noqa diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py index cbf14c6ead..803135e35d 100644 --- a/tests/test_entrypoint.py +++ b/tests/test_entrypoint.py @@ -12,6 +12,7 @@ import litellm import subprocess +@pytest.mark.skip(reason="local test") def test_decrypt_and_reset_env(): os.environ["DATABASE_URL"] = ( "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" @@ -29,6 +30,7 @@ def test_decrypt_and_reset_env(): print("DATABASE_URL={}".format(os.environ["DATABASE_URL"])) +@pytest.mark.skip(reason="local test") def test_entrypoint_decrypt_and_reset(): os.environ["DATABASE_URL"] = ( "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"