From 6697ca3d3a81760180a41ef5e06e071330026b7f Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 6 Nov 2024 06:59:55 -0800 Subject: [PATCH] refereshable boto credentials --- .../adapters/inference/bedrock/bedrock.py | 70 +++++++----- .../adapters/safety/bedrock/bedrock.py | 33 ++++-- llama_stack/providers/registry/inference.py | 2 +- llama_stack/providers/registry/safety.py | 2 +- .../utils/refreshable_boto_session.py | 108 ++++++++++++++++++ 5 files changed, 171 insertions(+), 44 deletions(-) create mode 100644 llama_stack/providers/utils/refreshable_boto_session.py diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 7258e4e7d..cc149084a 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig +from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession BEDROCK_SUPPORTED_MODELS = { @@ -441,38 +442,47 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): def _create_bedrock_client(config: BedrockConfig) -> BaseClient: - retries_config = { - k: v - for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, - ).items() - if v is not None - } + if config.aws_access_key_id and config.aws_secret_access_key: + retries_config = { + k: v + for k, v in dict( + total_max_attempts=config.total_max_attempts, + mode=config.retry_mode, + ).items() + if v is not None + } - config_args = { - k: v - for k, v in dict( - region_name=config.region_name, - retries=retries_config if retries_config else None, - connect_timeout=config.connect_timeout, - read_timeout=config.read_timeout, - ).items() - if v is not None - } + config_args = { + k: v + for k, v in dict( + region_name=config.region_name, + retries=retries_config if retries_config else None, + connect_timeout=config.connect_timeout, + read_timeout=config.read_timeout, + ).items() + if v is not None + } - boto3_config = Config(**config_args) + boto3_config = Config(**config_args) - session_args = { - "aws_access_key_id": config.aws_access_key_id, - "aws_secret_access_key": config.aws_secret_access_key, - "aws_session_token": config.aws_session_token, - "region_name": config.region_name, - "profile_name": config.profile_name, - } + session_args = { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + "aws_session_token": config.aws_session_token, + "region_name": config.region_name, + "profile_name": config.profile_name, + } - # Remove None values - session_args = {k: v for k, v in session_args.items() if v is not None} + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} - boto3_session = boto3.session.Session(**session_args) - return boto3_session.client("bedrock-runtime", config=boto3_config) + boto3_session = boto3.session.Session(**session_args) + return boto3_session.client("bedrock-runtime", config=boto3_config) + else: + return ( + RefreshableBotoSession( + region_name=config.region_name, profile_name=config.profile_name + ) + .refreshable_session() + .client("bedrock-runtime") + ) diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index 557949c1b..bafc27003 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -14,6 +14,7 @@ import boto3 from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession from .config import BedrockSafetyConfig @@ -27,19 +28,27 @@ BEDROCK_SUPPORTED_SHIELDS = [ def _create_bedrock_client(config: BedrockSafetyConfig, name: str): - session_args = { - "aws_access_key_id": config.aws_access_key_id, - "aws_secret_access_key": config.aws_secret_access_key, - "aws_session_token": config.aws_session_token, - "region_name": config.region_name, - "profile_name": config.profile_name, - } + if config.aws_access_key_id and config.aws_secret_access_key: + session_args = { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + "aws_session_token": config.aws_session_token, + "region_name": config.region_name, + "profile_name": config.profile_name, + } + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} - # Remove None values - session_args = {k: v for k, v in session_args.items() if v is not None} - - boto3_session = boto3.session.Session(**session_args) - return boto3_session.client(name) + boto3_session = boto3.session.Session(**session_args) + return boto3_session.client(name) + else: + return ( + RefreshableBotoSession( + region_name=config.region_name, profile_name=config.profile_name + ) + .refreshable_session() + .client(name) + ) class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 88265f1b4..22babb076 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -124,7 +124,7 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="bedrock", - pip_packages=["boto3"], + pip_packages=["boto3", "pytz"], module="llama_stack.providers.adapters.inference.bedrock", config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig", ), diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 3fa62479a..79ac4ca2b 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]: api=Api.safety, adapter=AdapterSpec( adapter_type="bedrock", - pip_packages=["boto3"], + pip_packages=["boto3", "pytz"], module="llama_stack.providers.adapters.safety.bedrock", config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig", ), diff --git a/llama_stack/providers/utils/refreshable_boto_session.py b/llama_stack/providers/utils/refreshable_boto_session.py new file mode 100644 index 000000000..789d7c612 --- /dev/null +++ b/llama_stack/providers/utils/refreshable_boto_session.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from datetime import datetime +from time import time + +import pytz +from boto3 import Session +from botocore.credentials import RefreshableCredentials +from botocore.session import get_session + + +class RefreshableBotoSession: + """ + Boto Helper class which lets us create a refreshable session so that we can cache the client or resource. + + Usage + ----- + session = RefreshableBotoSession().refreshable_session() + + client = session.client("bedrock-runtime") # we now can cache this client object without worrying about expiring credentials + """ + + def __init__( + self, + region_name: str = None, + profile_name: str = None, + session_ttl: int = 30000, + ): + """ + Initialize `RefreshableBotoSession` + + Parameters + ---------- + region_name : str (optional) + Default region when creating a new connection. Will check AWS_REGION or AWS_DEFAULT_REGION env vars if not provided. + + profile_name : str (optional) + The name of a profile to use. Will check environment variables before using profile. + """ + # Check environment variables for region + self.region_name = ( + region_name + or os.environ.get("AWS_REGION") + or os.environ.get("AWS_DEFAULT_REGION") + ) + self.profile_name = profile_name + self.session_ttl = session_ttl + + def __get_session_credentials(self): + """ + Get session credentials from environment variables or session + """ + # Check for credentials in environment variables first + if all( + key in os.environ for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"] + ): + expiry_time = ( + os.environ.get("EXPIRY_TIME") + or datetime.fromtimestamp(time() + self.session_ttl) + .replace(tzinfo=pytz.utc) + .isoformat() + ) + credentials = { + "access_key": os.environ["AWS_ACCESS_KEY_ID"], + "secret_key": os.environ["AWS_SECRET_ACCESS_KEY"], + "token": os.environ.get("AWS_SESSION_TOKEN"), # Optional + "expiry_time": expiry_time, + } + return credentials + + # Fall back to profile-based credentials + session = Session(region_name=self.region_name, profile_name=self.profile_name) + + session_credentials = session.get_credentials().get_frozen_credentials() + credentials = { + "access_key": session_credentials.access_key, + "secret_key": session_credentials.secret_key, + "token": session_credentials.token, + "expiry_time": datetime.fromtimestamp(time() + self.session_ttl) + .replace(tzinfo=pytz.utc) + .isoformat(), + } + + return credentials + + def refreshable_session(self) -> Session: + """ + Get refreshable boto3 session. + """ + # Get refreshable credentials + refreshable_credentials = RefreshableCredentials.create_from_metadata( + metadata=self.__get_session_credentials(), + refresh_using=self.__get_session_credentials, + method="sts-assume-role", + ) + + # attach refreshable credentials current session + session = get_session() + session._credentials = refreshable_credentials + session.set_config_variable("region", self.region_name) + autorefresh_session = Session(botocore_session=session) + + return autorefresh_session