Merge pull request #3712 from Manouchehri/oidc-bedrock-httpx-caching-part-1

Add IAM credential caching for the OIDC flow
This commit is contained in:
Krish Dholakia 2024-06-12 12:44:58 -07:00 committed by GitHub
commit 2d701e6e63
2 changed files with 121 additions and 57 deletions

View file

@ -53,7 +53,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaChunk,
)
from litellm.caching import DualCache
iam_cache = DualCache()
class AmazonCohereChatConfig:
"""
@ -325,38 +327,53 @@ class BedrockLLM(BaseLLM):
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
iam_creds_cache_key = json.dumps({
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
})
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
return session.get_credentials()
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
@ -1416,38 +1433,53 @@ class BedrockConverseLLM(BaseLLM):
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
iam_creds_cache_key = json.dumps({
"aws_web_identity_token": aws_web_identity_token,
"aws_role_name": aws_role_name,
"aws_session_name": aws_session_name,
"aws_region_name": aws_region_name,
})
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
if iam_creds_dict is None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
iam_creds_dict = {
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
"aws_session_token": sts_response["Credentials"]["SessionToken"],
"region_name": aws_region_name,
}
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
return session.get_credentials()
session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",

View file

@ -220,13 +220,13 @@ def test_completion_bedrock_claude_sts_oidc_auth():
aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
# TODO: This is using ai.moda's IAM role; we should use LiteLLM's IAM role eventually.
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
try:
litellm.set_verbose = True
response = completion(
response_1 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=10,
@ -236,8 +236,40 @@ def test_completion_bedrock_claude_sts_oidc_auth():
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# Add any assertions here to check the response
print(response)
print(response_1)
assert len(response_1.choices) > 0
assert len(response_1.choices[0].message.content) > 0
# This second call is to verify that the cache isn't breaking anything
response_2 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=5,
temperature=0.2,
aws_region_name=aws_region_name,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
print(response_2)
assert len(response_2.choices) > 0
assert len(response_2.choices[0].message.content) > 0
# This third call is to verify that the cache isn't used for a different region
response_3 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_region_name="us-east-1",
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
print(response_3)
assert len(response_3.choices) > 0
assert len(response_3.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
@ -255,7 +287,7 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
# TODO: This is using ai.moda's IAM role; we should use LiteLLM's IAM role eventually.
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
try: