diff --git a/distributions/bedrock/compose.yaml b/distributions/bedrock/compose.yaml new file mode 100644 index 000000000..f988e33d1 --- /dev/null +++ b/distributions/bedrock/compose.yaml @@ -0,0 +1,15 @@ +services: + llamastack: + image: distribution-bedrock + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-bedrock.yaml + ports: + - "5000:5000" + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-bedrock.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/bedrock/run.yaml b/distributions/bedrock/run.yaml new file mode 100644 index 000000000..bd9a89566 --- /dev/null +++ b/distributions/bedrock/run.yaml @@ -0,0 +1,46 @@ +version: '2' +built_at: '2024-11-01T17:40:45.325529' +image_name: local +name: bedrock +docker_image: null +conda_env: local +apis: +- shields +- agents +- models +- memory +- memory_banks +- inference +- safety +providers: + inference: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: + memory: + - provider_id: meta0 + provider_type: meta-reference + config: {} + safety: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: + agents: + - provider_id: meta0 + provider_type: meta-reference + config: + persistence_store: + type: sqlite + db_path: ~/.llama/runtime/kvstore.db + telemetry: + - provider_id: meta0 + provider_type: meta-reference + config: {} diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md b/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md new file mode 100644 index 000000000..28691d4e3 --- /dev/null +++ b/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md @@ -0,0 +1,58 @@ +# Bedrock Distribution + +### Connect to a Llama Stack Bedrock Endpoint +- You may connect to Amazon Bedrock APIs for running LLM inference + +The `llamastack/distribution-bedrock` distribution consists of the following provider configurations. + + +| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | +|----------------- |--------------- |---------------- |---------------- |---------------- |---------------- | +| **Provider(s)** | remote::bedrock | meta-reference | meta-reference | remote::bedrock | meta-reference | + + +### Docker: Start the Distribution (Single Node CPU) + +> [!NOTE] +> This assumes you have valid AWS credentials configured with access to Amazon Bedrock. + +``` +$ cd distributions/bedrock && docker compose up +``` + +Make sure in your `run.yaml` file, your inference provider is pointing to the correct AWS configuration. E.g. +``` +inference: + - provider_id: bedrock0 + provider_type: remote::bedrock + config: + aws_access_key_id: + aws_secret_access_key: + aws_session_token: + region_name: +``` + +### Conda llama stack run (Single Node CPU) + +```bash +llama stack build --template bedrock --image-type conda +# -- modify run.yaml with valid AWS credentials +llama stack run ./run.yaml +``` + +### (Optional) Update Model Serving Configuration + +Use `llama-stack-client models list` to check the available models served by Amazon Bedrock. + +``` +$ llama-stack-client models list ++------------------------------+------------------------------+---------------+------------+ +| identifier | llama_model | provider_id | metadata | ++==============================+==============================+===============+============+ +| Llama3.1-8B-Instruct | meta.llama3-1-8b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-70B-Instruct | meta.llama3-1-70b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-405B-Instruct | meta.llama3-1-405b-instruct-v1:0 | bedrock0 | {} | ++------------------------------+------------------------------+---------------+------------+ +``` diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index f3615dc4b..0b74fd259 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel): class ShieldStore(Protocol): - def get_shield(self, identifier: str) -> ShieldDef: ... + async def get_shield(self, identifier: str) -> ShieldDef: ... @runtime_checkable @@ -48,5 +48,5 @@ class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 7c8e3939a..fd5634442 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -46,7 +46,7 @@ class Shields(Protocol): async def list_shields(self) -> List[ShieldDefWithProvider]: ... @webmethod(route="/shields/get", method="GET") - async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ... + async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ... @webmethod(route="/shields/register", method="POST") async def register_shield(self, shield: ShieldDefWithProvider) -> None: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 348d8449d..760dbaf2f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): async def run_shield( self, - shield_type: str, + identifier: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: - return await self.routing_table.get_provider_impl(shield_type).run_shield( - shield_type=shield_type, + return await self.routing_table.get_provider_impl(identifier).run_shield( + identifier=identifier, messages=messages, params=params, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 6297182bc..bcf125bec 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: return await self.get_all_with_type("shield") - async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: - return await self.get_object_by_identifier(shield_type) + async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: + return await self.get_object_by_identifier(identifier) async def register_shield(self, shield: ShieldDefWithProvider) -> None: await self.register_object(shield) diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index caf886c0b..87b374de1 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -6,9 +6,7 @@ from typing import * # noqa: F403 -import boto3 from botocore.client import BaseClient -from botocore.config import Config from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -16,7 +14,9 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.apis.inference import * # noqa: F403 + from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig +from llama_stack.providers.utils.bedrock.client import create_bedrock_client BEDROCK_SUPPORTED_MODELS = { @@ -34,7 +34,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) self._config = config - self._client = _create_bedrock_client(config) + self._client = create_bedrock_client(config) self.formatter = ChatFormat(Tokenizer.get_instance()) @property @@ -437,43 +437,3 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() - - -def _create_bedrock_client(config: BedrockConfig) -> BaseClient: - retries_config = { - k: v - for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, - ).items() - if v is not None - } - - config_args = { - k: v - for k, v in dict( - region_name=config.region_name, - retries=retries_config if retries_config else None, - connect_timeout=config.connect_timeout, - read_timeout=config.read_timeout, - ).items() - if v is not None - } - - boto3_config = Config(**config_args) - - session_args = { - k: v - for k, v in dict( - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, - aws_session_token=config.aws_session_token, - region_name=config.region_name, - profile_name=config.profile_name, - ).items() - if v is not None - } - - boto3_session = boto3.session.Session(**session_args) - - return boto3_session.client("bedrock-runtime", config=boto3_config) diff --git a/llama_stack/providers/adapters/inference/bedrock/config.py b/llama_stack/providers/adapters/inference/bedrock/config.py index 72d2079b9..8e194700c 100644 --- a/llama_stack/providers/adapters/inference/bedrock/config.py +++ b/llama_stack/providers/adapters/inference/bedrock/config.py @@ -3,53 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import * # noqa: F403 from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig @json_schema_type -class BedrockConfig(BaseModel): - aws_access_key_id: Optional[str] = Field( - default=None, - description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", - ) - aws_secret_access_key: Optional[str] = Field( - default=None, - description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", - ) - aws_session_token: Optional[str] = Field( - default=None, - description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", - ) - region_name: Optional[str] = Field( - default=None, - description="The default AWS Region to use, for example, us-west-1 or us-west-2." - "Default use environment variable: AWS_DEFAULT_REGION", - ) - profile_name: Optional[str] = Field( - default=None, - description="The profile name that contains credentials to use." - "Default use environment variable: AWS_PROFILE", - ) - total_max_attempts: Optional[int] = Field( - default=None, - description="An integer representing the maximum number of attempts that will be made for a single request, " - "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", - ) - retry_mode: Optional[str] = Field( - default=None, - description="A string representing the type of retries Boto3 will perform." - "Default use environment variable: AWS_RETRY_MODE", - ) - connect_timeout: Optional[float] = Field( - default=60, - description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " - "The default is 60 seconds.", - ) - read_timeout: Optional[float] = Field( - default=60, - description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." - "The default is 60 seconds.", - ) +class BedrockConfig(BedrockBaseConfig): + pass diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index 3203e36f4..e14dbd2a4 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -9,11 +9,10 @@ import logging from typing import Any, Dict, List -import boto3 - from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig @@ -28,17 +27,13 @@ BEDROCK_SUPPORTED_SHIELDS = [ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: BedrockSafetyConfig) -> None: - if not config.aws_profile: - raise ValueError(f"Missing boto_client aws_profile in model info::{config}") self.config = config self.registered_shields = [] async def initialize(self) -> None: try: - print(f"initializing with profile --- > {self.config}") - self.boto_client = boto3.Session( - profile_name=self.config.aws_profile - ).client("bedrock-runtime") + self.bedrock_runtime_client = create_bedrock_client(self.config) + self.bedrock_client = create_bedrock_client(self.config, "bedrock") except Exception as e: raise RuntimeError("Error initializing BedrockSafetyAdapter") from e @@ -49,19 +44,28 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): raise ValueError("Registering dynamic shields is not supported") async def list_shields(self) -> List[ShieldDef]: - raise NotImplementedError( - """ - `list_shields` not implemented; this should read all guardrails from - bedrock and populate guardrailId and guardrailVersion in the ShieldDef. - """ - ) + response = self.bedrock_client.list_guardrails() + shields = [] + for guardrail in response["guardrails"]: + # populate the shield def with the guardrail id and version + shield_def = ShieldDef( + identifier=guardrail["id"], + shield_type=ShieldType.generic_content_shield.value, + params={ + "guardrailIdentifier": guardrail["id"], + "guardrailVersion": guardrail["version"], + }, + ) + self.registered_shields.append(shield_def) + shields.append(shield_def) + return shields async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format ```content = [ @@ -88,7 +92,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:" ) - response = self.boto_client.apply_guardrail( + response = self.bedrock_runtime_client.apply_guardrail( guardrailIdentifier=shield_params["guardrailIdentifier"], guardrailVersion=shield_params["guardrailVersion"], source="OUTPUT", # or 'INPUT' depending on your use case @@ -104,10 +108,12 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): # guardrails returns a list - however for this implementation we will leverage the last values metadata = dict(assessment) - return SafetyViolation( - user_message=user_message, - violation_level=ViolationLevel.ERROR, - metadata=metadata, + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) ) - return None + return RunShieldResponse() diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py index 2a8585262..8c61decf3 100644 --- a/llama_stack/providers/adapters/safety/bedrock/config.py +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -4,13 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel, Field + +from llama_models.schema_utils import json_schema_type + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig -class BedrockSafetyConfig(BaseModel): - """Configuration information for a guardrail that you want to use in the request.""" - - aws_profile: str = Field( - default="default", - description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", - ) +@json_schema_type +class BedrockSafetyConfig(BedrockBaseConfig): + pass diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index da45ed5b8..9f92626af 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -43,11 +43,11 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat ] async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") model = shield_def.params.get("model", "llama_guard") if model not in TOGETHER_SHIELD_MODEL_MAP: diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index fb5821f6a..915ddd303 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -32,18 +32,18 @@ class ShieldRunnerMixin: self.output_shields = output_shields async def run_multiple_shields( - self, messages: List[Message], shield_types: List[str] + self, messages: List[Message], identifiers: List[str] ) -> None: responses = await asyncio.gather( *[ self.safety_api.run_shield( - shield_type=shield_type, + identifier=identifier, messages=messages, ) - for shield_type in shield_types + for identifier in identifiers ] ) - for shield_type, response in zip(shield_types, responses): + for identifier, response in zip(identifiers, responses): if not response.violation: continue @@ -52,6 +52,6 @@ class ShieldRunnerMixin: raise SafetyException(violation) elif violation.violation_level == ViolationLevel.WARN: cprint( - f"[Warn]{shield_type} raised a warning", + f"[Warn]{identifier} raised a warning", color="red", ) diff --git a/llama_stack/providers/utils/bedrock/client.py b/llama_stack/providers/utils/bedrock/client.py new file mode 100644 index 000000000..77781c729 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/client.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import boto3 +from botocore.client import BaseClient +from botocore.config import Config + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +from llama_stack.providers.utils.bedrock.refreshable_boto_session import ( + RefreshableBotoSession, +) + + +def create_bedrock_client( + config: BedrockBaseConfig, service_name: str = "bedrock-runtime" +) -> BaseClient: + """Creates a boto3 client for Bedrock services with the given configuration. + + Args: + config: The Bedrock configuration containing AWS credentials and settings + service_name: The AWS service name to create client for (default: "bedrock-runtime") + + Returns: + A configured boto3 client + """ + if config.aws_access_key_id and config.aws_secret_access_key: + retries_config = { + k: v + for k, v in dict( + total_max_attempts=config.total_max_attempts, + mode=config.retry_mode, + ).items() + if v is not None + } + + config_args = { + k: v + for k, v in dict( + region_name=config.region_name, + retries=retries_config if retries_config else None, + connect_timeout=config.connect_timeout, + read_timeout=config.read_timeout, + ).items() + if v is not None + } + + boto3_config = Config(**config_args) + + session_args = { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + "aws_session_token": config.aws_session_token, + "region_name": config.region_name, + "profile_name": config.profile_name, + "session_ttl": config.session_ttl, + } + + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} + + boto3_session = boto3.session.Session(**session_args) + return boto3_session.client(service_name, config=boto3_config) + else: + return ( + RefreshableBotoSession( + region_name=config.region_name, + profile_name=config.profile_name, + session_ttl=config.session_ttl, + ) + .refreshable_session() + .client(service_name) + ) diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py new file mode 100644 index 000000000..55c5582a1 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/config.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class BedrockBaseConfig(BaseModel): + aws_access_key_id: Optional[str] = Field( + default=None, + description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", + ) + aws_secret_access_key: Optional[str] = Field( + default=None, + description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", + ) + aws_session_token: Optional[str] = Field( + default=None, + description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", + ) + region_name: Optional[str] = Field( + default=None, + description="The default AWS Region to use, for example, us-west-1 or us-west-2." + "Default use environment variable: AWS_DEFAULT_REGION", + ) + profile_name: Optional[str] = Field( + default=None, + description="The profile name that contains credentials to use." + "Default use environment variable: AWS_PROFILE", + ) + total_max_attempts: Optional[int] = Field( + default=None, + description="An integer representing the maximum number of attempts that will be made for a single request, " + "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", + ) + retry_mode: Optional[str] = Field( + default=None, + description="A string representing the type of retries Boto3 will perform." + "Default use environment variable: AWS_RETRY_MODE", + ) + connect_timeout: Optional[float] = Field( + default=60, + description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " + "The default is 60 seconds.", + ) + read_timeout: Optional[float] = Field( + default=60, + description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." + "The default is 60 seconds.", + ) + session_ttl: Optional[int] = Field( + default=3600, + description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", + ) diff --git a/llama_stack/providers/utils/bedrock/refreshable_boto_session.py b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py new file mode 100644 index 000000000..f37563930 --- /dev/null +++ b/llama_stack/providers/utils/bedrock/refreshable_boto_session.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import datetime +from time import time +from uuid import uuid4 + +from boto3 import Session +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session + + +class RefreshableBotoSession: + """ + Boto Helper class which lets us create a refreshable session so that we can cache the client or resource. + + Usage + ----- + session = RefreshableBotoSession().refreshable_session() + + client = session.client("s3") # we now can cache this client object without worrying about expiring credentials + """ + + def __init__( + self, + region_name: str = None, + profile_name: str = None, + sts_arn: str = None, + session_name: str = None, + session_ttl: int = 30000, + ): + """ + Initialize `RefreshableBotoSession` + + Parameters + ---------- + region_name : str (optional) + Default region when creating a new connection. + + profile_name : str (optional) + The name of a profile to use. + + sts_arn : str (optional) + The role arn to sts before creating a session. + + session_name : str (optional) + An identifier for the assumed role session. (required when `sts_arn` is given) + + session_ttl : int (optional) + An integer number to set the TTL for each session. Beyond this session, it will renew the token. + 50 minutes by default which is before the default role expiration of 1 hour + """ + + self.region_name = region_name + self.profile_name = profile_name + self.sts_arn = sts_arn + self.session_name = session_name or uuid4().hex + self.session_ttl = session_ttl + + def __get_session_credentials(self): + """ + Get session credentials + """ + session = Session(region_name=self.region_name, profile_name=self.profile_name) + + # if sts_arn is given, get credential by assuming the given role + if self.sts_arn: + sts_client = session.client( + service_name="sts", region_name=self.region_name + ) + response = sts_client.assume_role( + RoleArn=self.sts_arn, + RoleSessionName=self.session_name, + DurationSeconds=self.session_ttl, + ).get("Credentials") + + credentials = { + "access_key": response.get("AccessKeyId"), + "secret_key": response.get("SecretAccessKey"), + "token": response.get("SessionToken"), + "expiry_time": response.get("Expiration").isoformat(), + } + else: + session_credentials = session.get_credentials().get_frozen_credentials() + credentials = { + "access_key": session_credentials.access_key, + "secret_key": session_credentials.secret_key, + "token": session_credentials.token, + "expiry_time": datetime.datetime.fromtimestamp( + time() + self.session_ttl, datetime.timezone.utc + ).isoformat(), + } + + return credentials + + def refreshable_session(self) -> Session: + """ + Get refreshable boto3 session. + """ + # Get refreshable credentials + refreshable_credentials = RefreshableCredentials.create_from_metadata( + metadata=self.__get_session_credentials(), + refresh_using=self.__get_session_credentials, + method="sts-assume-role", + ) + + # attach refreshable credentials current session + session = get_session() + session._credentials = refreshable_credentials + session.set_config_variable("region", self.region_name) + autorefresh_session = Session(botocore_session=session) + + return autorefresh_session