diff --git a/litellm/proxy/auth/rds_iam_token.py b/litellm/proxy/auth/rds_iam_token.py new file mode 100644 index 0000000000..719e16ac05 --- /dev/null +++ b/litellm/proxy/auth/rds_iam_token.py @@ -0,0 +1,11 @@ +def generate_iam_auth_token(db_host, db_port, db_user) -> str: + from urllib.parse import quote + + import boto3 + + client = boto3.client("rds") + token = client.generate_db_auth_token( + DBHostname=db_host, Port=db_port, DBUsername=db_user + ) + cleaned_token = quote(token, safe="") + return cleaned_token diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 1d0eef6a0e..af78085a4c 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -171,6 +171,12 @@ def is_port_in_use(port): is_flag=True, help="Calls async endpoints /queue/requests and /queue/response", ) +@click.option( + "--iam_token_db_auth", + default=False, + is_flag=True, + help="Connects to RDS DB with IAM token", +) @click.option( "--num_requests", default=10, @@ -222,6 +228,7 @@ def run_server( local, num_workers, test_async, + iam_token_db_auth, num_requests, use_queue, health, @@ -442,6 +449,24 @@ def run_server( db_connection_pool_limit = 100 db_connection_timeout = 60 + ### GET DB TOKEN FOR IAM AUTH ### + + if iam_token_db_auth: + from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token + + db_host = os.getenv("DATABASE_HOST") + db_port = os.getenv("DATABASE_PORT") + db_user = os.getenv("DATABASE_USER") + db_name = os.getenv("DATABASE_NAME") + + token = generate_iam_auth_token( + db_host=db_host, db_port=db_port, db_user=db_user + ) + + # print(f"token: {token}") + _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}" + os.environ["DATABASE_URL"] = _db_url + ### DECRYPT ENV VAR ### from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var