refereshable boto credentials

This commit is contained in:
Dinesh Yeduguru 2024-11-06 06:59:55 -08:00
parent 7d28dc380e
commit 6697ca3d3a
5 changed files with 171 additions and 44 deletions

View file

@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession
BEDROCK_SUPPORTED_MODELS = {
@ -441,38 +442,47 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
if config.aws_access_key_id and config.aws_secret_access_key:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
boto3_config = Config(**config_args)
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)
else:
return (
RefreshableBotoSession(
region_name=config.region_name, profile_name=config.profile_name
)
.refreshable_session()
.client("bedrock-runtime")
)

View file

@ -14,6 +14,7 @@ import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.refreshable_boto_session import RefreshableBotoSession
from .config import BedrockSafetyConfig
@ -27,19 +28,27 @@ BEDROCK_SUPPORTED_SHIELDS = [
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
if config.aws_access_key_id and config.aws_secret_access_key:
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name)
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name)
else:
return (
RefreshableBotoSession(
region_name=config.region_name, profile_name=config.profile_name
)
.refreshable_session()
.client(name)
)
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -124,7 +124,7 @@ def available_providers() -> List[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3"],
pip_packages=["boto3", "pytz"],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),

View file

@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3"],
pip_packages=["boto3", "pytz"],
module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
),

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from datetime import datetime
from time import time
import pytz
from boto3 import Session
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
class RefreshableBotoSession:
"""
Boto Helper class which lets us create a refreshable session so that we can cache the client or resource.
Usage
-----
session = RefreshableBotoSession().refreshable_session()
client = session.client("bedrock-runtime") # we now can cache this client object without worrying about expiring credentials
"""
def __init__(
self,
region_name: str = None,
profile_name: str = None,
session_ttl: int = 30000,
):
"""
Initialize `RefreshableBotoSession`
Parameters
----------
region_name : str (optional)
Default region when creating a new connection. Will check AWS_REGION or AWS_DEFAULT_REGION env vars if not provided.
profile_name : str (optional)
The name of a profile to use. Will check environment variables before using profile.
"""
# Check environment variables for region
self.region_name = (
region_name
or os.environ.get("AWS_REGION")
or os.environ.get("AWS_DEFAULT_REGION")
)
self.profile_name = profile_name
self.session_ttl = session_ttl
def __get_session_credentials(self):
"""
Get session credentials from environment variables or session
"""
# Check for credentials in environment variables first
if all(
key in os.environ for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]
):
expiry_time = (
os.environ.get("EXPIRY_TIME")
or datetime.fromtimestamp(time() + self.session_ttl)
.replace(tzinfo=pytz.utc)
.isoformat()
)
credentials = {
"access_key": os.environ["AWS_ACCESS_KEY_ID"],
"secret_key": os.environ["AWS_SECRET_ACCESS_KEY"],
"token": os.environ.get("AWS_SESSION_TOKEN"), # Optional
"expiry_time": expiry_time,
}
return credentials
# Fall back to profile-based credentials
session = Session(region_name=self.region_name, profile_name=self.profile_name)
session_credentials = session.get_credentials().get_frozen_credentials()
credentials = {
"access_key": session_credentials.access_key,
"secret_key": session_credentials.secret_key,
"token": session_credentials.token,
"expiry_time": datetime.fromtimestamp(time() + self.session_ttl)
.replace(tzinfo=pytz.utc)
.isoformat(),
}
return credentials
def refreshable_session(self) -> Session:
"""
Get refreshable boto3 session.
"""
# Get refreshable credentials
refreshable_credentials = RefreshableCredentials.create_from_metadata(
metadata=self.__get_session_credentials(),
refresh_using=self.__get_session_credentials,
method="sts-assume-role",
)
# attach refreshable credentials current session
session = get_session()
session._credentials = refreshable_credentials
session.set_config_variable("region", self.region_name)
autorefresh_session = Session(botocore_session=session)
return autorefresh_session