mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge pull request #4437 from BerriAI/litellm_aws_kms_decryption
fix(initial-commit): decrypts aws keys in entrypoint.sh
This commit is contained in:
commit
d0c89ddbe3
5 changed files with 247 additions and 46 deletions
|
@ -1,48 +1,13 @@
|
||||||
#!/bin/sh
|
#!/bin/bash
|
||||||
|
echo $(pwd)
|
||||||
|
|
||||||
# Check if DATABASE_URL is not set
|
# Run the Python migration script
|
||||||
if [ -z "$DATABASE_URL" ]; then
|
python3 litellm/proxy/prisma_migration.py
|
||||||
# Check if all required variables are provided
|
|
||||||
if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then
|
# Check if the Python script executed successfully
|
||||||
# Construct DATABASE_URL from the provided variables
|
if [ $? -eq 0 ]; then
|
||||||
DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}"
|
echo "Migration script ran successfully!"
|
||||||
export DATABASE_URL
|
|
||||||
else
|
else
|
||||||
echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL."
|
echo "Migration script failed!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
fi
|
|
||||||
|
|
||||||
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
|
|
||||||
if [ -z "$DIRECT_URL" ]; then
|
|
||||||
export DIRECT_URL=$DATABASE_URL
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Apply migrations
|
|
||||||
retry_count=0
|
|
||||||
max_retries=3
|
|
||||||
exit_code=1
|
|
||||||
|
|
||||||
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
|
||||||
do
|
|
||||||
retry_count=$((retry_count+1))
|
|
||||||
echo "Attempt $retry_count..."
|
|
||||||
|
|
||||||
# Run the Prisma db push command
|
|
||||||
prisma db push --accept-data-loss
|
|
||||||
|
|
||||||
exit_code=$?
|
|
||||||
|
|
||||||
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
|
||||||
echo "Retrying in 10 seconds..."
|
|
||||||
sleep 10
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
if [ $exit_code -ne 0 ]; then
|
|
||||||
echo "Unable to push database changes after $max_retries retries."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Database push successful!"
|
|
||||||
|
|
||||||
|
|
68
litellm/proxy/prisma_migration.py
Normal file
68
litellm/proxy/prisma_migration.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
# What is this?
|
||||||
|
## Script to apply initial prisma migration on Docker setup
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("./")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||||
|
|
||||||
|
if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True":
|
||||||
|
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
|
||||||
|
new_env_var = decrypt_env_var()
|
||||||
|
|
||||||
|
for k, v in new_env_var.items():
|
||||||
|
os.environ[k] = v
|
||||||
|
|
||||||
|
# Check if DATABASE_URL is not set
|
||||||
|
database_url = os.getenv("DATABASE_URL")
|
||||||
|
if not database_url:
|
||||||
|
# Check if all required variables are provided
|
||||||
|
database_host = os.getenv("DATABASE_HOST")
|
||||||
|
database_username = os.getenv("DATABASE_USERNAME")
|
||||||
|
database_password = os.getenv("DATABASE_PASSWORD")
|
||||||
|
database_name = os.getenv("DATABASE_NAME")
|
||||||
|
|
||||||
|
if database_host and database_username and database_password and database_name:
|
||||||
|
# Construct DATABASE_URL from the provided variables
|
||||||
|
database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}"
|
||||||
|
os.environ["DATABASE_URL"] = database_url
|
||||||
|
else:
|
||||||
|
print( # noqa
|
||||||
|
"Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." # noqa
|
||||||
|
)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
|
||||||
|
direct_url = os.getenv("DIRECT_URL")
|
||||||
|
if not direct_url:
|
||||||
|
os.environ["DIRECT_URL"] = database_url
|
||||||
|
|
||||||
|
# Apply migrations
|
||||||
|
retry_count = 0
|
||||||
|
max_retries = 3
|
||||||
|
exit_code = 1
|
||||||
|
|
||||||
|
while retry_count < max_retries and exit_code != 0:
|
||||||
|
retry_count += 1
|
||||||
|
print(f"Attempt {retry_count}...") # noqa
|
||||||
|
|
||||||
|
# Run the Prisma db push command
|
||||||
|
result = subprocess.run(
|
||||||
|
["prisma", "db", "push", "--accept-data-loss"], capture_output=True
|
||||||
|
)
|
||||||
|
exit_code = result.returncode
|
||||||
|
|
||||||
|
if exit_code != 0 and retry_count < max_retries:
|
||||||
|
print("Retrying in 10 seconds...") # noqa
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
if exit_code != 0:
|
||||||
|
print(f"Unable to push database changes after {max_retries} retries.") # noqa
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("Database push successful!") # noqa
|
|
@ -442,6 +442,20 @@ def run_server(
|
||||||
|
|
||||||
db_connection_pool_limit = 100
|
db_connection_pool_limit = 100
|
||||||
db_connection_timeout = 60
|
db_connection_timeout = 60
|
||||||
|
### DECRYPT ENV VAR ###
|
||||||
|
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||||
|
|
||||||
|
if (
|
||||||
|
os.getenv("USE_AWS_KMS", None) is not None
|
||||||
|
and os.getenv("USE_AWS_KMS") == "True"
|
||||||
|
):
|
||||||
|
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
|
||||||
|
new_env_var = decrypt_env_var()
|
||||||
|
|
||||||
|
for k, v in new_env_var.items():
|
||||||
|
os.environ[k] = v
|
||||||
|
|
||||||
if config is not None:
|
if config is not None:
|
||||||
"""
|
"""
|
||||||
Allow user to pass in db url via config
|
Allow user to pass in db url via config
|
||||||
|
@ -459,6 +473,7 @@ def run_server(
|
||||||
|
|
||||||
proxy_config = ProxyConfig()
|
proxy_config = ProxyConfig()
|
||||||
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
|
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
|
||||||
|
|
||||||
### LITELLM SETTINGS ###
|
### LITELLM SETTINGS ###
|
||||||
litellm_settings = _config.get("litellm_settings", None)
|
litellm_settings = _config.get("litellm_settings", None)
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -8,9 +8,12 @@ Requires:
|
||||||
* `pip install boto3>=1.28.57`
|
* `pip install boto3>=1.28.57`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import litellm
|
import ast
|
||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm.proxy._types import KeyManagementSystem
|
from litellm.proxy._types import KeyManagementSystem
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,3 +60,94 @@ def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
class AWSKeyManagementService_V2:
|
||||||
|
"""
|
||||||
|
V2 Clean Class for decrypting keys from AWS KeyManagementService
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.validate_environment()
|
||||||
|
self.kms_client = self.load_aws_kms(use_aws_kms=True)
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
if "AWS_REGION_NAME" not in os.environ:
|
||||||
|
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||||
|
|
||||||
|
## CHECK IF LICENSE IN ENV ## - premium feature
|
||||||
|
if os.getenv("LITELLM_LICENSE", None) is None:
|
||||||
|
raise ValueError(
|
||||||
|
"AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_aws_kms(self, use_aws_kms: Optional[bool]):
|
||||||
|
if use_aws_kms is None or use_aws_kms is False:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
validate_environment()
|
||||||
|
|
||||||
|
# Create a Secrets Manager client
|
||||||
|
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
|
||||||
|
|
||||||
|
return kms_client
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def decrypt_value(self, secret_name: str) -> Any:
|
||||||
|
if self.kms_client is None:
|
||||||
|
raise ValueError("kms_client is None")
|
||||||
|
encrypted_value = os.getenv(secret_name, None)
|
||||||
|
if encrypted_value is None:
|
||||||
|
raise Exception(
|
||||||
|
"AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
|
||||||
|
)
|
||||||
|
if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"):
|
||||||
|
encrypted_value = encrypted_value.replace("aws_kms/", "")
|
||||||
|
|
||||||
|
# Decode the base64 encoded ciphertext
|
||||||
|
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||||
|
|
||||||
|
# Set up the parameters for the decrypt call
|
||||||
|
params = {"CiphertextBlob": ciphertext_blob}
|
||||||
|
# Perform the decryption
|
||||||
|
response = self.kms_client.decrypt(**params)
|
||||||
|
|
||||||
|
# Extract and decode the plaintext
|
||||||
|
plaintext = response["Plaintext"]
|
||||||
|
secret = plaintext.decode("utf-8")
|
||||||
|
if isinstance(secret, str):
|
||||||
|
secret = secret.strip()
|
||||||
|
try:
|
||||||
|
secret_value_as_bool = ast.literal_eval(secret)
|
||||||
|
if isinstance(secret_value_as_bool, bool):
|
||||||
|
return secret_value_as_bool
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return secret
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
- look for all values in the env with `aws_kms/<hashed_key>`
|
||||||
|
- decrypt keys
|
||||||
|
- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_env_var() -> dict[str, Any]:
|
||||||
|
# setup client class
|
||||||
|
aws_kms = AWSKeyManagementService_V2()
|
||||||
|
# iterate through env - for `aws_kms/`
|
||||||
|
new_values = {}
|
||||||
|
for k, v in os.environ.items():
|
||||||
|
if v is not None and isinstance(v, str) and v.startswith("aws_kms/"):
|
||||||
|
decrypted_value = aws_kms.decrypt_value(secret_name=k)
|
||||||
|
# reset env var
|
||||||
|
new_values[k] = decrypted_value
|
||||||
|
|
||||||
|
return new_values
|
||||||
|
|
59
tests/test_entrypoint.py
Normal file
59
tests/test_entrypoint.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# What is this?
|
||||||
|
## Unit tests for 'entrypoint.sh'
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="local test")
|
||||||
|
def test_decrypt_and_reset_env():
|
||||||
|
os.environ["DATABASE_URL"] = (
|
||||||
|
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||||
|
)
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||||
|
decrypt_and_reset_env_var,
|
||||||
|
)
|
||||||
|
|
||||||
|
decrypt_and_reset_env_var()
|
||||||
|
|
||||||
|
assert os.environ["DATABASE_URL"] is not None
|
||||||
|
assert isinstance(os.environ["DATABASE_URL"], str)
|
||||||
|
assert not os.environ["DATABASE_URL"].startswith("aws_kms/")
|
||||||
|
|
||||||
|
print("DATABASE_URL={}".format(os.environ["DATABASE_URL"]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="local test")
|
||||||
|
def test_entrypoint_decrypt_and_reset():
|
||||||
|
os.environ["DATABASE_URL"] = (
|
||||||
|
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||||
|
)
|
||||||
|
command = "./entrypoint.sh"
|
||||||
|
directory = ".." # Relative to the current directory
|
||||||
|
|
||||||
|
# Run the command using subprocess
|
||||||
|
result = subprocess.run(
|
||||||
|
command, shell=True, cwd=directory, capture_output=True, text=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print the output for debugging purposes
|
||||||
|
print("STDOUT:", result.stdout)
|
||||||
|
print("STDERR:", result.stderr)
|
||||||
|
|
||||||
|
# Assert the script ran successfully
|
||||||
|
assert result.returncode == 0, "The shell script did not execute successfully"
|
||||||
|
assert (
|
||||||
|
"DECRYPTS VALUE" in result.stdout
|
||||||
|
), "Expected output not found in script output"
|
||||||
|
assert (
|
||||||
|
"Database push successful!" in result.stdout
|
||||||
|
), "Expected output not found in script output"
|
||||||
|
|
||||||
|
assert False
|
Loading…
Add table
Add a link
Reference in a new issue