mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 17:23:00 +00:00 
			
		
		
		
# What does this PR do? Dropped Python 3.10, updated pyproject and dependencies, and removed blocks of code with special handling for `enum.StrEnum`. Closes #2458. Signed-off-by: Charlie Doern <cdoern@redhat.com>
		
			
				
	
	
		
			112 lines
		
	
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			112 lines
		
	
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import datetime
 | |
| from time import time
 | |
| from uuid import uuid4
 | |
| 
 | |
| from boto3 import Session
 | |
| from botocore.credentials import RefreshableCredentials
 | |
| from botocore.session import get_session
 | |
| 
 | |
| 
 | |
class RefreshableBotoSession:
    """
    Boto Helper class which lets us create a refreshable session so that we can cache the client or resource.

    The returned session carries botocore `RefreshableCredentials` whose refresh callback re-fetches
    credentials from this object, so long-lived clients built from it do not have to be recreated
    when credentials rotate.

    Usage
    -----
    session = RefreshableBotoSession().refreshable_session()

    client = session.client("s3") # we now can cache this client object without worrying about expiring credentials
    """

    def __init__(
        self,
        region_name: str | None = None,
        profile_name: str | None = None,
        sts_arn: str | None = None,
        session_name: str | None = None,
        session_ttl: int = 30000,
    ):
        """
        Initialize `RefreshableBotoSession`

        Parameters
        ----------
        region_name : str (optional)
            Default region when creating a new connection.

        profile_name : str (optional)
            The name of a profile to use.

        sts_arn : str (optional)
            The role arn to sts before creating a session.

        session_name : str (optional)
            An identifier for the assumed role session. Only used when `sts_arn` is given;
            a random hex UUID is generated when it is omitted.

        session_ttl : int (optional)
            Lifetime, in seconds, of each set of credentials; they are renewed as this window
            runs out. Defaults to 30000 seconds (about 8.3 hours).
            NOTE: when `sts_arn` is given this value is sent to STS `AssumeRole` as
            `DurationSeconds`, which AWS rejects if it exceeds the role's maximum session
            duration (1 hour unless the role is configured otherwise) or is below 900.
            Pass a smaller value such as 3000 (50 minutes) for roles with the default limit.
        """

        self.region_name = region_name
        self.profile_name = profile_name
        self.sts_arn = sts_arn
        # Fall back to a random identifier so AssumeRole always receives a RoleSessionName.
        self.session_name = session_name or uuid4().hex
        self.session_ttl = session_ttl

    def __get_session_credentials(self):
        """
        Fetch a fresh set of credentials in the metadata format consumed by
        `RefreshableCredentials.create_from_metadata`.

        Called once when the refreshable session is built and again by botocore each time
        the credentials need refreshing.

        Returns
        -------
        dict
            Keys `access_key`, `secret_key`, `token` and `expiry_time` (ISO-8601 string).

        Raises
        ------
        botocore.exceptions.NoCredentialsError
            When no `sts_arn` is set and the underlying session cannot locate any credentials.
        """
        session = Session(region_name=self.region_name, profile_name=self.profile_name)

        # if sts_arn is given, get credential by assuming the given role
        if self.sts_arn:
            sts_client = session.client(service_name="sts", region_name=self.region_name)
            response = sts_client.assume_role(
                RoleArn=self.sts_arn,
                RoleSessionName=self.session_name,
                DurationSeconds=self.session_ttl,
            ).get("Credentials")

            credentials = {
                "access_key": response.get("AccessKeyId"),
                "secret_key": response.get("SecretAccessKey"),
                "token": response.get("SessionToken"),
                "expiry_time": response.get("Expiration").isoformat(),
            }
        else:
            base_credentials = session.get_credentials()
            if base_credentials is None:
                # get_credentials() yields None when no provider in the chain found credentials;
                # raise botocore's standard error rather than an AttributeError on None.
                from botocore.exceptions import NoCredentialsError

                raise NoCredentialsError()

            session_credentials = base_credentials.get_frozen_credentials()
            credentials = {
                "access_key": session_credentials.access_key,
                "secret_key": session_credentials.secret_key,
                "token": session_credentials.token,
                # Synthetic expiry: forces a re-read after session_ttl seconds.
                # NOTE(review): if the base provider issues temporary credentials that expire
                # sooner than session_ttl, they may lapse before this refresh fires — confirm.
                "expiry_time": datetime.datetime.fromtimestamp(time() + self.session_ttl, datetime.UTC).isoformat(),
            }

        return credentials

    def refreshable_session(self) -> Session:
        """
        Get refreshable boto3 session.

        Returns
        -------
        boto3.Session
            A session configured with `self.region_name` whose credentials are a
            `RefreshableCredentials` object that calls back into this instance for new
            credentials.
        """
        # Get refreshable credentials
        refreshable_credentials = RefreshableCredentials.create_from_metadata(
            metadata=self.__get_session_credentials(),
            refresh_using=self.__get_session_credentials,
            # Label stored on the credentials object; the same value is used whether or not
            # a role is actually assumed.
            method="sts-assume-role",
        )

        # attach refreshable credentials to the current session
        session = get_session()
        # NOTE: assigns botocore's private attribute directly; the public
        # `set_credentials` only builds static credentials from raw keys.
        session._credentials = refreshable_credentials
        session.set_config_variable("region", self.region_name)
        autorefresh_session = Session(botocore_session=session)

        return autorefresh_session
 |