diff --git a/llama_stack/providers/utils/refreshable_boto_session.py b/llama_stack/providers/utils/refreshable_boto_session.py index 789d7c612..a3a19dea6 100644 --- a/llama_stack/providers/utils/refreshable_boto_session.py +++ b/llama_stack/providers/utils/refreshable_boto_session.py @@ -4,9 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os from datetime import datetime from time import time +from uuid import uuid4 import pytz from boto3 import Session @@ -22,14 +22,16 @@ class RefreshableBotoSession: ----- session = RefreshableBotoSession().refreshable_session() - client = session.client("bedrock-runtime") # we now can cache this client object without worrying about expiring credentials + client = session.client("s3") # we now can cache this client object without worrying about expiring credentials """ def __init__( self, region_name: str = None, profile_name: str = None, - session_ttl: int = 30000, + sts_arn: str = None, + session_name: str = None, + session_ttl: int = 3000, ): """ Initialize `RefreshableBotoSession` @@ -37,54 +39,61 @@ class RefreshableBotoSession: Parameters ---------- region_name : str (optional) - Default region when creating a new connection. Will check AWS_REGION or AWS_DEFAULT_REGION env vars if not provided. + Default region when creating a new connection. profile_name : str (optional) - The name of a profile to use. Will check environment variables before using profile. + The name of a profile to use. + + sts_arn : str (optional) + The role arn to sts before creating a session. + + session_name : str (optional) + An identifier for the assumed role session. (required when `sts_arn` is given) + + session_ttl : int (optional) + An integer number to set the TTL for each session. Beyond this session, it will renew the token. + 50 minutes by default which is before the default role expiration of 1 hour """ - # Check environment variables for region - self.region_name = ( - region_name - or os.environ.get("AWS_REGION") - or os.environ.get("AWS_DEFAULT_REGION") - ) + + self.region_name = region_name self.profile_name = profile_name + self.sts_arn = sts_arn + self.session_name = session_name or uuid4().hex self.session_ttl = session_ttl def __get_session_credentials(self): """ - Get session credentials from environment variables or session + Get session credentials """ - # Check for credentials in environment variables first - if all( - key in os.environ for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"] - ): - expiry_time = ( - os.environ.get("EXPIRY_TIME") - or datetime.fromtimestamp(time() + self.session_ttl) - .replace(tzinfo=pytz.utc) - .isoformat() - ) - credentials = { - "access_key": os.environ["AWS_ACCESS_KEY_ID"], - "secret_key": os.environ["AWS_SECRET_ACCESS_KEY"], - "token": os.environ.get("AWS_SESSION_TOKEN"), # Optional - "expiry_time": expiry_time, - } - return credentials - - # Fall back to profile-based credentials session = Session(region_name=self.region_name, profile_name=self.profile_name) - session_credentials = session.get_credentials().get_frozen_credentials() - credentials = { - "access_key": session_credentials.access_key, - "secret_key": session_credentials.secret_key, - "token": session_credentials.token, - "expiry_time": datetime.fromtimestamp(time() + self.session_ttl) - .replace(tzinfo=pytz.utc) - .isoformat(), - } + # if sts_arn is given, get credential by assuming the given role + if self.sts_arn: + sts_client = session.client( + service_name="sts", region_name=self.region_name + ) + response = sts_client.assume_role( + RoleArn=self.sts_arn, + RoleSessionName=self.session_name, + DurationSeconds=self.session_ttl, + ).get("Credentials") + + credentials = { + "access_key": response.get("AccessKeyId"), + "secret_key": response.get("SecretAccessKey"), + "token": response.get("SessionToken"), + "expiry_time": response.get("Expiration").isoformat(), + } + else: + session_credentials = session.get_credentials().get_frozen_credentials() + credentials = { + "access_key": session_credentials.access_key, + "secret_key": session_credentials.secret_key, + "token": session_credentials.token, + "expiry_time": datetime.fromtimestamp(time() + self.session_ttl) + .replace(tzinfo=pytz.utc) + .isoformat(), + } return credentials