Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 19:54:13 +00:00)
Merge pull request #3712 from Manouchehri/oidc-bedrock-httpx-caching-part-1
Add IAM cred caching for OIDC flow
This commit is contained in:
commit
2d701e6e63
2 changed files with 121 additions and 57 deletions
|
@ -53,7 +53,9 @@ from litellm.types.llms.openai import (
|
||||||
ChatCompletionToolCallFunctionChunk,
|
ChatCompletionToolCallFunctionChunk,
|
||||||
ChatCompletionDeltaChunk,
|
ChatCompletionDeltaChunk,
|
||||||
)
|
)
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
iam_cache = DualCache()
|
||||||
|
|
||||||
class AmazonCohereChatConfig:
|
class AmazonCohereChatConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -325,38 +327,53 @@ class BedrockLLM(BaseLLM):
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
### CHECK STS ###
|
### CHECK STS ###
|
||||||
if (
|
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||||
aws_web_identity_token is not None
|
iam_creds_cache_key = json.dumps({
|
||||||
and aws_role_name is not None
|
"aws_web_identity_token": aws_web_identity_token,
|
||||||
and aws_session_name is not None
|
"aws_role_name": aws_role_name,
|
||||||
):
|
"aws_session_name": aws_session_name,
|
||||||
oidc_token = get_secret(aws_web_identity_token)
|
"aws_region_name": aws_region_name,
|
||||||
|
})
|
||||||
|
|
||||||
if oidc_token is None:
|
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
|
||||||
raise BedrockError(
|
if iam_creds_dict is None:
|
||||||
message="OIDC token could not be retrieved from secret manager.",
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
status_code=401,
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise BedrockError(
|
||||||
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
region_name=aws_region_name,
|
||||||
|
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
sts_client = boto3.client("sts")
|
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||||
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||||
|
sts_response = sts_client.assume_role_with_web_identity(
|
||||||
|
RoleArn=aws_role_name,
|
||||||
|
RoleSessionName=aws_session_name,
|
||||||
|
WebIdentityToken=oidc_token,
|
||||||
|
DurationSeconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
iam_creds_dict = {
|
||||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||||
sts_response = sts_client.assume_role_with_web_identity(
|
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||||
RoleArn=aws_role_name,
|
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||||
RoleSessionName=aws_session_name,
|
"region_name": aws_region_name,
|
||||||
WebIdentityToken=oidc_token,
|
}
|
||||||
DurationSeconds=3600,
|
|
||||||
)
|
|
||||||
|
|
||||||
session = boto3.Session(
|
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
|
||||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
|
||||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
|
||||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
|
||||||
region_name=aws_region_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
return session.get_credentials()
|
session = boto3.Session(**iam_creds_dict)
|
||||||
|
|
||||||
|
iam_creds = session.get_credentials()
|
||||||
|
|
||||||
|
return iam_creds
|
||||||
elif aws_role_name is not None and aws_session_name is not None:
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
|
@ -1416,38 +1433,53 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
### CHECK STS ###
|
### CHECK STS ###
|
||||||
if (
|
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||||
aws_web_identity_token is not None
|
iam_creds_cache_key = json.dumps({
|
||||||
and aws_role_name is not None
|
"aws_web_identity_token": aws_web_identity_token,
|
||||||
and aws_session_name is not None
|
"aws_role_name": aws_role_name,
|
||||||
):
|
"aws_session_name": aws_session_name,
|
||||||
oidc_token = get_secret(aws_web_identity_token)
|
"aws_region_name": aws_region_name,
|
||||||
|
})
|
||||||
|
|
||||||
if oidc_token is None:
|
iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key)
|
||||||
raise BedrockError(
|
if iam_creds_dict is None:
|
||||||
message="OIDC token could not be retrieved from secret manager.",
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
status_code=401,
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise BedrockError(
|
||||||
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
region_name=aws_region_name,
|
||||||
|
endpoint_url=f"https://sts.{aws_region_name}.amazonaws.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
sts_client = boto3.client("sts")
|
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||||
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||||
|
sts_response = sts_client.assume_role_with_web_identity(
|
||||||
|
RoleArn=aws_role_name,
|
||||||
|
RoleSessionName=aws_session_name,
|
||||||
|
WebIdentityToken=oidc_token,
|
||||||
|
DurationSeconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
iam_creds_dict = {
|
||||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||||
sts_response = sts_client.assume_role_with_web_identity(
|
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||||
RoleArn=aws_role_name,
|
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||||
RoleSessionName=aws_session_name,
|
"region_name": aws_region_name,
|
||||||
WebIdentityToken=oidc_token,
|
}
|
||||||
DurationSeconds=3600,
|
|
||||||
)
|
|
||||||
|
|
||||||
session = boto3.Session(
|
iam_cache.set_cache(key=iam_creds_cache_key, value=json.dumps(iam_creds_dict), ttl=3600 - 60)
|
||||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
|
||||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
|
||||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
|
||||||
region_name=aws_region_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
return session.get_credentials()
|
session = boto3.Session(**iam_creds_dict)
|
||||||
|
|
||||||
|
iam_creds = session.get_credentials()
|
||||||
|
|
||||||
|
return iam_creds
|
||||||
elif aws_role_name is not None and aws_session_name is not None:
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
|
|
|
@ -220,13 +220,13 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
||||||
aws_web_identity_token = "oidc/circleci_v2/"
|
aws_web_identity_token = "oidc/circleci_v2/"
|
||||||
aws_region_name = os.environ["AWS_REGION_NAME"]
|
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||||
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||||
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
|
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
|
||||||
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
|
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
||||||
response = completion(
|
response_1 = completion(
|
||||||
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
|
@ -236,8 +236,40 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name="my-test-session",
|
aws_session_name="my-test-session",
|
||||||
)
|
)
|
||||||
# Add any assertions here to check the response
|
print(response_1)
|
||||||
print(response)
|
assert len(response_1.choices) > 0
|
||||||
|
assert len(response_1.choices[0].message.content) > 0
|
||||||
|
|
||||||
|
# This second call is to verify that the cache isn't breaking anything
|
||||||
|
response_2 = completion(
|
||||||
|
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=5,
|
||||||
|
temperature=0.2,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name="my-test-session",
|
||||||
|
)
|
||||||
|
print(response_2)
|
||||||
|
assert len(response_2.choices) > 0
|
||||||
|
assert len(response_2.choices[0].message.content) > 0
|
||||||
|
|
||||||
|
# This third call is to verify that the cache isn't used for a different region
|
||||||
|
response_3 = completion(
|
||||||
|
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=6,
|
||||||
|
temperature=0.3,
|
||||||
|
aws_region_name="us-east-1",
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name="my-test-session",
|
||||||
|
)
|
||||||
|
print(response_3)
|
||||||
|
assert len(response_3.choices) > 0
|
||||||
|
assert len(response_3.choices[0].message.content) > 0
|
||||||
|
|
||||||
except RateLimitError:
|
except RateLimitError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -255,7 +287,7 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
|
||||||
aws_web_identity_token = "oidc/circleci_v2/"
|
aws_web_identity_token = "oidc/circleci_v2/"
|
||||||
aws_region_name = os.environ["AWS_REGION_NAME"]
|
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||||
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||||
# TODO: This is using David's IAM role, we should use Litellm's IAM role eventually
|
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
|
||||||
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
|
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue