mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(proxy_cli.py): support iam-based auth to rds
Initial pr for iam-based auth support for rds
This commit is contained in:
parent
a6e67d643b
commit
1cc7c7fc59
2 changed files with 36 additions and 0 deletions
11
litellm/proxy/auth/rds_iam_token.py
Normal file
11
litellm/proxy/auth/rds_iam_token.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
def generate_iam_auth_token(db_host, db_port, db_user) -> str:
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
client = boto3.client("rds")
|
||||||
|
token = client.generate_db_auth_token(
|
||||||
|
DBHostname=db_host, Port=db_port, DBUsername=db_user
|
||||||
|
)
|
||||||
|
cleaned_token = quote(token, safe="")
|
||||||
|
return cleaned_token
|
|
@ -171,6 +171,12 @@ def is_port_in_use(port):
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="Calls async endpoints /queue/requests and /queue/response",
|
help="Calls async endpoints /queue/requests and /queue/response",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--iam_token_db_auth",
|
||||||
|
default=False,
|
||||||
|
is_flag=True,
|
||||||
|
help="Connects to RDS DB with IAM token",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--num_requests",
|
"--num_requests",
|
||||||
default=10,
|
default=10,
|
||||||
|
@ -222,6 +228,7 @@ def run_server(
|
||||||
local,
|
local,
|
||||||
num_workers,
|
num_workers,
|
||||||
test_async,
|
test_async,
|
||||||
|
iam_token_db_auth,
|
||||||
num_requests,
|
num_requests,
|
||||||
use_queue,
|
use_queue,
|
||||||
health,
|
health,
|
||||||
|
@ -442,6 +449,24 @@ def run_server(
|
||||||
|
|
||||||
db_connection_pool_limit = 100
|
db_connection_pool_limit = 100
|
||||||
db_connection_timeout = 60
|
db_connection_timeout = 60
|
||||||
|
### GET DB TOKEN FOR IAM AUTH ###
|
||||||
|
|
||||||
|
if iam_token_db_auth:
|
||||||
|
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
||||||
|
|
||||||
|
db_host = os.getenv("DATABASE_HOST")
|
||||||
|
db_port = os.getenv("DATABASE_PORT")
|
||||||
|
db_user = os.getenv("DATABASE_USER")
|
||||||
|
db_name = os.getenv("DATABASE_NAME")
|
||||||
|
|
||||||
|
token = generate_iam_auth_token(
|
||||||
|
db_host=db_host, db_port=db_port, db_user=db_user
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(f"token: {token}")
|
||||||
|
_db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
|
||||||
|
os.environ["DATABASE_URL"] = _db_url
|
||||||
|
|
||||||
### DECRYPT ENV VAR ###
|
### DECRYPT ENV VAR ###
|
||||||
|
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue