feat(bedrock.py): Support using OIDC tokens.

This commit is contained in:
David Manouchehri 2024-05-07 15:01:46 +00:00
parent b0b87ba79b
commit 47449d19e8
2 changed files with 70 additions and 1 deletions

View file

@ -550,6 +550,7 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
@ -566,6 +567,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
@ -581,6 +583,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### SET REGION NAME
@ -619,7 +622,38 @@ def init_bedrock_client(
config = boto3.session.Config()
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts"
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
sts_client = boto3.client(
"sts",
@ -752,6 +786,7 @@ def completion(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop("aws_bedrock_client", None)
@ -766,6 +801,7 @@ def completion(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_web_identity_token=aws_web_identity_token,
extra_headers=extra_headers,
timeout=timeout,
)
@ -1288,6 +1324,7 @@ def embedding(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -1295,6 +1332,7 @@ def embedding(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
@ -1377,6 +1415,7 @@ def image_generation(
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -1384,6 +1423,7 @@ def image_generation(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,