Merge pull request #3712 from Manouchehri/oidc-bedrock-httpx-caching-part-1

Add IAM cred caching for OIDC flow
This commit is contained in:
Krish Dholakia 2024-06-12 12:44:58 -07:00 committed by GitHub
commit 2d701e6e63
2 changed files with 121 additions and 57 deletions

View file

@ -53,7 +53,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk, ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaChunk, ChatCompletionDeltaChunk,
) )
from litellm.caching import DualCache
iam_cache = DualCache()
class AmazonCohereChatConfig: class AmazonCohereChatConfig:
""" """
@ -325,38 +327,53 @@ class BedrockLLM(BaseLLM):
) = params_to_check ) = params_to_check
### CHECK STS ### ### CHECK STS ###
if ( if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
aws_web_identity_token is not None iam_creds_cache_key = json.dumps({
and aws_role_name is not None "aws_web_identity_token": aws_web_identity_token,
and aws_session_name is not None "aws_role_name": aws_role_name,
): "aws_session_name": aws_session_name,
oidc_token = get_secret(aws_web_identity_token) "aws_region_name": aws_region_name,
})
if oidc_token is None: iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
raise BedrockError( if iam_creds_dict is None:
message="OIDC token could not be retrieved from secret manager.", oidc_token = get_secret(aws_web_identity_token)
status_code=401,
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
) )
sts_client = boto3.client("sts") # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html iam_creds_dict = {
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
sts_response = sts_client.assume_role_with_web_identity( "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
RoleArn=aws_role_name, "aws_session_token": sts_response["Credentials"]["SessionToken"],
RoleSessionName=aws_session_name, "region_name": aws_region_name,
WebIdentityToken=oidc_token, }
DurationSeconds=3600,
)
session = boto3.Session( iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials() session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None: elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",
@ -1416,38 +1433,53 @@ class BedrockConverseLLM(BaseLLM):
) = params_to_check ) = params_to_check
### CHECK STS ### ### CHECK STS ###
if ( if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
aws_web_identity_token is not None iam_creds_cache_key = json.dumps({
and aws_role_name is not None "aws_web_identity_token": aws_web_identity_token,
and aws_session_name is not None "aws_role_name": aws_role_name,
): "aws_session_name": aws_session_name,
oidc_token = get_secret(aws_web_identity_token) "aws_region_name": aws_region_name,
})
if oidc_token is None: iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
raise BedrockError( if iam_creds_dict is None:
message="OIDC token could not be retrieved from secret manager.", oidc_token = get_secret(aws_web_identity_token)
status_code=401,
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts",
region_name=aws_region_name,
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
) )
sts_client = boto3.client("sts") # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html iam_creds_dict = {
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
sts_response = sts_client.assume_role_with_web_identity( "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
RoleArn=aws_role_name, "aws_session_token": sts_response["Credentials"]["SessionToken"],
RoleSessionName=aws_session_name, "region_name": aws_region_name,
WebIdentityToken=oidc_token, }
DurationSeconds=3600,
)
session = boto3.Session( iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials() session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials()
return iam_creds
elif aws_role_name is not None and aws_session_name is not None: elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",

View file

@ -220,13 +220,13 @@ def test_completion_bedrock_claude_sts_oidc_auth():
aws_web_identity_token = "oidc/circleci_v2/" aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"] aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci" aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
try: try:
litellm.set_verbose = True litellm.set_verbose = True
response = completion( response_1 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0", model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
@ -236,8 +236,40 @@ def test_completion_bedrock_claude_sts_oidc_auth():
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name="my-test-session", aws_session_name="my-test-session",
) )
# Add any assertions here to check the response print(response_1)
print(response) assert len(response_1.choices) > 0
assert len(response_1.choices[0].message.content) > 0
# This second call is to verify that the cache isn't breaking anything
response_2 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=5,
temperature=0.2,
aws_region_name=aws_region_name,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
print(response_2)
assert len(response_2.choices) > 0
assert len(response_2.choices[0].message.content) > 0
# This third call is to verify that the cache isn't used for a different region
response_3 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_region_name="us-east-1",
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
print(response_3)
assert len(response_3.choices) > 0
assert len(response_3.choices[0].message.content) > 0
except RateLimitError: except RateLimitError:
pass pass
except Exception as e: except Exception as e:
@ -255,7 +287,7 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
aws_web_identity_token = "oidc/circleci_v2/" aws_web_identity_token = "oidc/circleci_v2/"
aws_region_name = os.environ["AWS_REGION_NAME"] aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci" aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
try: try: