remove env vars

This commit is contained in:
Dinesh Yeduguru 2024-11-06 12:26:22 -08:00
parent 6697ca3d3a
commit 2101cb08c7

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from datetime import datetime
from time import time
from uuid import uuid4
import pytz
from boto3 import Session
@ -22,14 +22,16 @@ class RefreshableBotoSession:
-----
session = RefreshableBotoSession().refreshable_session()
client = session.client("bedrock-runtime") # we now can cache this client object without worrying about expiring credentials
client = session.client("s3") # we now can cache this client object without worrying about expiring credentials
"""
def __init__(
self,
region_name: str = None,
profile_name: str = None,
session_ttl: int = 30000,
sts_arn: str = None,
session_name: str = None,
session_ttl: int = 3000,
):
"""
Initialize `RefreshableBotoSession`
@ -37,54 +39,61 @@ class RefreshableBotoSession:
Parameters
----------
region_name : str (optional)
Default region when creating a new connection. Will check AWS_REGION or AWS_DEFAULT_REGION env vars if not provided.
Default region when creating a new connection.
profile_name : str (optional)
The name of a profile to use. Will check environment variables before using profile.
The name of a profile to use.
sts_arn : str (optional)
The role arn to sts before creating a session.
session_name : str (optional)
An identifier for the assumed role session. (required when `sts_arn` is given)
session_ttl : int (optional)
An integer number to set the TTL for each session. Beyond this session, it will renew the token.
50 minutes by default which is before the default role expiration of 1 hour
"""
# Check environment variables for region
self.region_name = (
region_name
or os.environ.get("AWS_REGION")
or os.environ.get("AWS_DEFAULT_REGION")
)
self.region_name = region_name
self.profile_name = profile_name
self.sts_arn = sts_arn
self.session_name = session_name or uuid4().hex
self.session_ttl = session_ttl
def __get_session_credentials(self):
"""
Get session credentials from environment variables or session
Get session credentials
"""
# Check for credentials in environment variables first
if all(
key in os.environ for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]
):
expiry_time = (
os.environ.get("EXPIRY_TIME")
or datetime.fromtimestamp(time() + self.session_ttl)
.replace(tzinfo=pytz.utc)
.isoformat()
)
credentials = {
"access_key": os.environ["AWS_ACCESS_KEY_ID"],
"secret_key": os.environ["AWS_SECRET_ACCESS_KEY"],
"token": os.environ.get("AWS_SESSION_TOKEN"), # Optional
"expiry_time": expiry_time,
}
return credentials
# Fall back to profile-based credentials
session = Session(region_name=self.region_name, profile_name=self.profile_name)
session_credentials = session.get_credentials().get_frozen_credentials()
credentials = {
"access_key": session_credentials.access_key,
"secret_key": session_credentials.secret_key,
"token": session_credentials.token,
"expiry_time": datetime.fromtimestamp(time() + self.session_ttl)
.replace(tzinfo=pytz.utc)
.isoformat(),
}
# if sts_arn is given, get credential by assuming the given role
if self.sts_arn:
sts_client = session.client(
service_name="sts", region_name=self.region_name
)
response = sts_client.assume_role(
RoleArn=self.sts_arn,
RoleSessionName=self.session_name,
DurationSeconds=self.session_ttl,
).get("Credentials")
credentials = {
"access_key": response.get("AccessKeyId"),
"secret_key": response.get("SecretAccessKey"),
"token": response.get("SessionToken"),
"expiry_time": response.get("Expiration").isoformat(),
}
else:
session_credentials = session.get_credentials().get_frozen_credentials()
credentials = {
"access_key": session_credentials.access_key,
"secret_key": session_credentials.secret_key,
"token": session_credentials.token,
"expiry_time": datetime.fromtimestamp(time() + self.session_ttl)
.replace(tzinfo=pytz.utc)
.isoformat(),
}
return credentials