Merge pull request #3712 from Manouchehri/oidc-bedrock-httpx-caching-part-1

Add IAM cred caching for OIDC flow
This commit is contained in:
Krish Dholakia 2024-06-12 12:44:58 -07:00 committed by GitHub
commit 2d701e6e63
2 changed files with 121 additions and 57 deletions

View file

@ -53,7 +53,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaChunk,
)
from litellm.caching import DualCache
iam_cache = DualCache()
class AmazonCohereChatConfig:
"""
@ -325,38 +327,53 @@ class BedrockLLM(BaseLLM):
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
iam_creds_cache_key = json.dumps({
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
})
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
return session.get_credentials()
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
@ -1416,38 +1433,53 @@ class BedrockConverseLLM(BaseLLM):
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
iam_creds_cache_key = json.dumps({
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
})
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
return session.get_credentials()
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",