mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge pull request #5057 from BerriAI/litellm_rds_iam_auth
feat(proxy_cli.py): support iam-based auth to rds
This commit is contained in:
commit
036a6821d5
2 changed files with 204 additions and 0 deletions
179
litellm/proxy/auth/rds_iam_token.py
Normal file
179
litellm/proxy/auth/rds_iam_token.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
def init_rds_client(
|
||||||
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_region_name: Optional[str] = None,
|
||||||
|
aws_session_name: Optional[str] = None,
|
||||||
|
aws_profile_name: Optional[str] = None,
|
||||||
|
aws_role_name: Optional[str] = None,
|
||||||
|
aws_web_identity_token: Optional[str] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
):
|
||||||
|
from litellm import get_secret
|
||||||
|
|
||||||
|
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
# Define the list of parameters to check
|
||||||
|
params_to_check = [
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Iterate over parameters and update if needed
|
||||||
|
for i, param in enumerate(params_to_check):
|
||||||
|
if param and param.startswith("os.environ/"):
|
||||||
|
params_to_check[i] = get_secret(param)
|
||||||
|
# Assign updated values back to parameters
|
||||||
|
(
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
|
) = params_to_check
|
||||||
|
|
||||||
|
### SET REGION NAME
|
||||||
|
region_name = aws_region_name
|
||||||
|
if aws_region_name:
|
||||||
|
region_name = aws_region_name
|
||||||
|
elif litellm_aws_region_name:
|
||||||
|
region_name = litellm_aws_region_name
|
||||||
|
elif standard_aws_region_name:
|
||||||
|
region_name = standard_aws_region_name
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
|
||||||
|
)
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
if isinstance(timeout, float):
|
||||||
|
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||||
|
elif isinstance(timeout, httpx.Timeout):
|
||||||
|
config = boto3.session.Config(
|
||||||
|
connect_timeout=timeout.connect, read_timeout=timeout.read
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config = boto3.session.Config()
|
||||||
|
|
||||||
|
### CHECK STS ###
|
||||||
|
if (
|
||||||
|
aws_web_identity_token is not None
|
||||||
|
and aws_role_name is not None
|
||||||
|
and aws_session_name is not None
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
oidc_token = open(aws_web_identity_token).read() # check if filepath
|
||||||
|
except Exception:
|
||||||
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise Exception(
|
||||||
|
"OIDC token could not be retrieved from secret manager.",
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_client = boto3.client("sts")
|
||||||
|
|
||||||
|
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||||
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||||
|
sts_response = sts_client.assume_role_with_web_identity(
|
||||||
|
RoleArn=aws_role_name,
|
||||||
|
RoleSessionName=aws_session_name,
|
||||||
|
WebIdentityToken=oidc_token,
|
||||||
|
DurationSeconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="rds",
|
||||||
|
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||||
|
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||||
|
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||||
|
region_name=region_name,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
|
# use sts if role name passed in
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_response = sts_client.assume_role(
|
||||||
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||||
|
)
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="rds",
|
||||||
|
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||||
|
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||||
|
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||||
|
region_name=region_name,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
elif aws_access_key_id is not None:
|
||||||
|
# uses auth params passed to completion
|
||||||
|
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="rds",
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
region_name=region_name,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
elif aws_profile_name is not None:
|
||||||
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
|
||||||
|
client = boto3.Session(profile_name=aws_profile_name).client(
|
||||||
|
service_name="rds",
|
||||||
|
region_name=region_name,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||||
|
# boto3 automatically reads env variables
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="bedrock-runtime",
|
||||||
|
region_name=region_name,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def generate_iam_auth_token(db_host, db_port, db_user) -> str:
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
boto_client = init_rds_client(
|
||||||
|
aws_region_name=os.getenv("AWS_REGION_NAME"),
|
||||||
|
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
|
||||||
|
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||||
|
aws_session_name=os.getenv("AWS_SESSION_NAME"),
|
||||||
|
aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
|
||||||
|
aws_role_name=os.getenv("AWS_ROLE_NAME"),
|
||||||
|
aws_web_identity_token=os.getenv("AWS_WEB_IDENTITY_TOKEN"),
|
||||||
|
)
|
||||||
|
|
||||||
|
token = boto_client.generate_db_auth_token(
|
||||||
|
DBHostname=db_host, Port=db_port, DBUsername=db_user
|
||||||
|
)
|
||||||
|
cleaned_token = quote(token, safe="")
|
||||||
|
return cleaned_token
|
|
@ -177,6 +177,12 @@ def is_port_in_use(port):
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="Calls async endpoints /queue/requests and /queue/response",
|
help="Calls async endpoints /queue/requests and /queue/response",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--iam_token_db_auth",
|
||||||
|
default=False,
|
||||||
|
is_flag=True,
|
||||||
|
help="Connects to RDS DB with IAM token",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--num_requests",
|
"--num_requests",
|
||||||
default=10,
|
default=10,
|
||||||
|
@ -228,6 +234,7 @@ def run_server(
|
||||||
local,
|
local,
|
||||||
num_workers,
|
num_workers,
|
||||||
test_async,
|
test_async,
|
||||||
|
iam_token_db_auth,
|
||||||
num_requests,
|
num_requests,
|
||||||
use_queue,
|
use_queue,
|
||||||
health,
|
health,
|
||||||
|
@ -448,6 +455,24 @@ def run_server(
|
||||||
|
|
||||||
db_connection_pool_limit = 100
|
db_connection_pool_limit = 100
|
||||||
db_connection_timeout = 60
|
db_connection_timeout = 60
|
||||||
|
### GET DB TOKEN FOR IAM AUTH ###
|
||||||
|
|
||||||
|
if iam_token_db_auth:
|
||||||
|
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
||||||
|
|
||||||
|
db_host = os.getenv("DATABASE_HOST")
|
||||||
|
db_port = os.getenv("DATABASE_PORT")
|
||||||
|
db_user = os.getenv("DATABASE_USER")
|
||||||
|
db_name = os.getenv("DATABASE_NAME")
|
||||||
|
|
||||||
|
token = generate_iam_auth_token(
|
||||||
|
db_host=db_host, db_port=db_port, db_user=db_user
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(f"token: {token}")
|
||||||
|
_db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
|
||||||
|
os.environ["DATABASE_URL"] = _db_url
|
||||||
|
|
||||||
### DECRYPT ENV VAR ###
|
### DECRYPT ENV VAR ###
|
||||||
|
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue