mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Merge pull request #3712 from Manouchehri/oidc-bedrock-httpx-caching-part-1
Add IAM cred caching for OIDC flow
This commit is contained in:
commit
2d701e6e63
2 changed files with 121 additions and 57 deletions
|
@ -53,7 +53,9 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionDeltaChunk,
|
||||
)
|
||||
from litellm.caching import DualCache
|
||||
|
||||
iam_cache = DualCache()
|
||||
|
||||
class AmazonCohereChatConfig:
|
||||
"""
|
||||
|
@ -325,38 +327,53 @@ class BedrockLLM(BaseLLM):
|
|||
) = params_to_check
|
||||
|
||||
### CHECK STS ###
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||
iam_creds_cache_key = json.dumps({
|
||||
"aws_web_identity_token": aws_web_identity_token,
|
||||
"aws_role_name": aws_role_name,
|
||||
"aws_session_name": aws_session_name,
|
||||
"aws_region_name": aws_region_name,
|
||||
})
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
|
||||
if iam_creds_dict is None:
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
region_name=aws_region_name,
|
||||
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
|
||||
)
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
iam_creds_dict = {
|
||||
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||
"region_name": aws_region_name,
|
||||
}
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
|
||||
|
||||
return session.get_credentials()
|
||||
session = boto3.Session(**iam_creds_dict)
|
||||
|
||||
iam_creds = session.get_credentials()
|
||||
|
||||
return iam_creds
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
|
@ -1416,38 +1433,53 @@ class BedrockConverseLLM(BaseLLM):
|
|||
) = params_to_check
|
||||
|
||||
### CHECK STS ###
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||
iam_creds_cache_key = json.dumps({
|
||||
"aws_web_identity_token": aws_web_identity_token,
|
||||
"aws_role_name": aws_role_name,
|
||||
"aws_session_name": aws_session_name,
|
||||
"aws_region_name": aws_region_name,
|
||||
})
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
|
||||
if iam_creds_dict is None:
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
region_name=aws_region_name,
|
||||
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
|
||||
)
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
iam_creds_dict = {
|
||||
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||
"region_name": aws_region_name,
|
||||
}
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
|
||||
|
||||
return session.get_credentials()
|
||||
session = boto3.Session(**iam_creds_dict)
|
||||
|
||||
iam_creds = session.get_credentials()
|
||||
|
||||
return iam_creds
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
|
|
|
@ -220,13 +220,13 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
|||
aws_web_identity_token = "oidc/circleci_v2/"
|
||||
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
|
||||
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
|
||||
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
|
||||
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
||||
response = completion(
|
||||
response_1 = completion(
|
||||
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
|
@ -236,8 +236,40 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
|||
aws_role_name=aws_role_name,
|
||||
aws_session_name="my-test-session",
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
print(response_1)
|
||||
assert len(response_1.choices) > 0
|
||||
assert len(response_1.choices[0].message.content) > 0
|
||||
|
||||
# This second call is to verify that the cache isn't breaking anything
|
||||
response_2 = completion(
|
||||
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=5,
|
||||
temperature=0.2,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name="my-test-session",
|
||||
)
|
||||
print(response_2)
|
||||
assert len(response_2.choices) > 0
|
||||
assert len(response_2.choices[0].message.content) > 0
|
||||
|
||||
# This third call is to verify that the cache isn't used for a different region
|
||||
response_3 = completion(
|
||||
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=6,
|
||||
temperature=0.3,
|
||||
aws_region_name="us-east-1",
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name="my-test-session",
|
||||
)
|
||||
print(response_3)
|
||||
assert len(response_3.choices) > 0
|
||||
assert len(response_3.choices[0].message.content) > 0
|
||||
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -255,7 +287,7 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
|
|||
aws_web_identity_token = "oidc/circleci_v2/"
|
||||
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
|
||||
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
|
||||
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
|
||||
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue