diff --git a/entrypoint.sh b/entrypoint.sh index 6e47dde12..a028e5426 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,47 +1,13 @@ -#!/bin/sh +#!/bin/bash +echo $(pwd) -# Check if DATABASE_URL is not set -if [ -z "$DATABASE_URL" ]; then - # Check if all required variables are provided - if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then - # Construct DATABASE_URL from the provided variables - DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}" - export DATABASE_URL - else - echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." - exit 1 - fi -fi +# Run the Python migration script +python3 litellm/proxy/prisma_migration.py -# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations -if [ -z "$DIRECT_URL" ]; then - export DIRECT_URL=$DATABASE_URL -fi - -# Apply migrations -retry_count=0 -max_retries=3 -exit_code=1 - -until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ] -do - retry_count=$((retry_count+1)) - echo "Attempt $retry_count..." - - # Run the Prisma db push command - prisma db push --accept-data-loss - - exit_code=$? - - if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then - echo "Retrying in 10 seconds..." - sleep 10 - fi -done - -if [ $exit_code -ne 0 ]; then - echo "Unable to push database changes after $max_retries retries." +# Check if the Python script executed successfully +if [ $? -eq 0 ]; then + echo "Migration script ran successfully!" +else + echo "Migration script failed!" exit 1 fi - -echo "Database push successful!" diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py new file mode 100644 index 000000000..6ee09c22b --- /dev/null +++ b/litellm/proxy/prisma_migration.py @@ -0,0 +1,68 @@ +# What is this? +## Script to apply initial prisma migration on Docker setup + +import os +import subprocess +import sys +import time + +sys.path.insert( + 0, os.path.abspath("./") +) # Adds the parent directory to the system path +from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var + +if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": + ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV + new_env_var = decrypt_env_var() + + for k, v in new_env_var.items(): + os.environ[k] = v + +# Check if DATABASE_URL is not set +database_url = os.getenv("DATABASE_URL") +if not database_url: + # Check if all required variables are provided + database_host = os.getenv("DATABASE_HOST") + database_username = os.getenv("DATABASE_USERNAME") + database_password = os.getenv("DATABASE_PASSWORD") + database_name = os.getenv("DATABASE_NAME") + + if database_host and database_username and database_password and database_name: + # Construct DATABASE_URL from the provided variables + database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}" + os.environ["DATABASE_URL"] = database_url + else: + print( # noqa + "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." # noqa + ) + exit(1) + +# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations +direct_url = os.getenv("DIRECT_URL") +if not direct_url: + os.environ["DIRECT_URL"] = database_url + +# Apply migrations +retry_count = 0 +max_retries = 3 +exit_code = 1 + +while retry_count < max_retries and exit_code != 0: + retry_count += 1 + print(f"Attempt {retry_count}...") # noqa + + # Run the Prisma db push command + result = subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"], capture_output=True + ) + exit_code = result.returncode + + if exit_code != 0 and retry_count < max_retries: + print("Retrying in 10 seconds...") # noqa + time.sleep(10) + +if exit_code != 0: + print(f"Unable to push database changes after {max_retries} retries.") # noqa + exit(1) + +print("Database push successful!") # noqa diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py index cbf14c6ea..803135e35 100644 --- a/tests/test_entrypoint.py +++ b/tests/test_entrypoint.py @@ -12,6 +12,7 @@ import litellm import subprocess +@pytest.mark.skip(reason="local test") def test_decrypt_and_reset_env(): os.environ["DATABASE_URL"] = ( "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" @@ -29,6 +30,7 @@ def test_decrypt_and_reset_env(): print("DATABASE_URL={}".format(os.environ["DATABASE_URL"])) +@pytest.mark.skip(reason="local test") def test_entrypoint_decrypt_and_reset(): os.environ["DATABASE_URL"] = ( "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"