llama-stack/llama_stack/providers/utils/bedrock/refreshable_boto_session.py
Yuan Tang 34ab7a3b6c
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
2025-02-02 06:46:45 -08:00

114 lines
3.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datetime
from time import time
from typing import Optional
from uuid import uuid4

from boto3 import Session
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
class RefreshableBotoSession:
    """
    Boto helper that builds a boto3 session backed by refreshable credentials,
    so that clients and resources created from it can be cached.

    Usage
    -----
    session = RefreshableBotoSession().refreshable_session()
    client = session.client("s3")  # can be cached without worrying about expiring credentials
    """

    def __init__(
        self,
        region_name: Optional[str] = None,
        profile_name: Optional[str] = None,
        sts_arn: Optional[str] = None,
        session_name: Optional[str] = None,
        session_ttl: int = 30000,
    ):
        """
        Initialize `RefreshableBotoSession`

        Parameters
        ----------
        region_name : str (optional)
            Default region when creating a new connection.
        profile_name : str (optional)
            The name of a profile to use.
        sts_arn : str (optional)
            ARN of the role to assume through STS before creating a session.
            When omitted, the credentials resolved by boto3 are used as they are.
        session_name : str (optional)
            An identifier for the assumed role session. A random hex string is
            generated when none (or an empty one) is given.
        session_ttl : int (optional)
            Lifetime, in seconds, of each set of credentials; once it has
            elapsed the credentials are fetched again. Defaults to 30000
            seconds (about 8 hours 20 minutes). When `sts_arn` is given the
            value is sent to STS as `DurationSeconds`.
            NOTE(review): STS rejects durations longer than the role's
            configured maximum session duration (1 hour unless raised on the
            role), so the default presumably only works for roles that allow
            long sessions - confirm, or pass a smaller value.
        """
        self.region_name = region_name
        self.profile_name = profile_name
        self.sts_arn = sts_arn
        # A role session always needs a name, so fall back to a random one.
        self.session_name = session_name or uuid4().hex
        self.session_ttl = session_ttl

    def __get_session_credentials(self) -> dict:
        """
        Fetch a fresh set of credentials.

        Returns
        -------
        dict
            Mapping with the keys `access_key`, `secret_key`, `token` and
            `expiry_time` (ISO 8601 string), i.e. the metadata layout passed to
            `RefreshableCredentials.create_from_metadata`.

        Raises
        ------
        botocore.exceptions.NoCredentialsError
            If no role is assumed and boto3 cannot locate any credentials.
        """
        session = Session(region_name=self.region_name, profile_name=self.profile_name)

        # If sts_arn is given, get credentials by assuming the given role.
        if self.sts_arn:
            sts_client = session.client(service_name="sts", region_name=self.region_name)
            assumed = sts_client.assume_role(
                RoleArn=self.sts_arn,
                RoleSessionName=self.session_name,
                DurationSeconds=self.session_ttl,
            )["Credentials"]
            # Index the fields directly: a missing field then fails here with
            # its name instead of turning into a None credential value.
            return {
                "access_key": assumed["AccessKeyId"],
                "secret_key": assumed["SecretAccessKey"],
                "token": assumed["SessionToken"],
                "expiry_time": assumed["Expiration"].isoformat(),
            }

        resolved = session.get_credentials()
        if resolved is None:
            # get_credentials() returns None when nothing is configured; raise
            # the dedicated botocore error instead of failing with an
            # AttributeError on None.
            from botocore.exceptions import NoCredentialsError

            raise NoCredentialsError()

        frozen = resolved.get_frozen_credentials()
        # No expiry is read from the resolved credentials on this path, so one
        # is synthesized: `session_ttl` seconds from now, in UTC.
        expires_at = datetime.datetime.fromtimestamp(time() + self.session_ttl, datetime.timezone.utc)
        return {
            "access_key": frozen.access_key,
            "secret_key": frozen.secret_key,
            "token": frozen.token,
            "expiry_time": expires_at.isoformat(),
        }

    def refreshable_session(self) -> Session:
        """
        Get refreshable boto3 session.

        Returns
        -------
        Session
            A boto3 session wrapping a botocore session whose credentials are
            a `RefreshableCredentials` object fed by this helper.
        """
        # The first set of credentials is fetched eagerly; the same method is
        # handed over as the refresh callback for later renewals.
        # NOTE(review): `method` is labelled "sts-assume-role" even when no
        # role is assumed (sts_arn unset); kept as is.
        refreshable_credentials = RefreshableCredentials.create_from_metadata(
            metadata=self.__get_session_credentials(),
            refresh_using=self.__get_session_credentials,
            method="sts-assume-role",
        )

        # Attach the refreshable credentials to a new botocore session.
        # NOTE(review): this writes botocore's private `_credentials`
        # attribute - re-check when upgrading botocore.
        session = get_session()
        session._credentials = refreshable_credentials
        session.set_config_variable("region", self.region_name)
        return Session(botocore_session=session)