From 47449d19e8b2dfbccda12b41c9ed249f0ecaa8f5 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 15:01:46 +0000 Subject: [PATCH] feat(bedrock.py): Support using OIDC tokens. --- litellm/llms/bedrock.py | 42 +++++++++++++++++++++++- litellm/tests/test_bedrock_completion.py | 29 ++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 2f26ae4a9a..fe0aca44e7 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -550,6 +550,7 @@ def init_bedrock_client( aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, aws_role_name: Optional[str] = None, + aws_web_identity_token: Optional[str] = None, extra_headers: Optional[dict] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, ): @@ -566,6 +567,7 @@ def init_bedrock_client( aws_session_name, aws_profile_name, aws_role_name, + aws_web_identity_token, ] # Iterate over parameters and update if needed @@ -581,6 +583,7 @@ def init_bedrock_client( aws_session_name, aws_profile_name, aws_role_name, + aws_web_identity_token, ) = params_to_check ### SET REGION NAME @@ -619,7 +622,38 @@ def init_bedrock_client( config = boto3.session.Config() ### CHECK STS ### - if aws_role_name is not None and aws_session_name is not None: + if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None: + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) + + sts_client = boto3.client( + "sts" + ) + + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) + + client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + ) + elif aws_role_name is not None and aws_session_name is not None: # use sts if role name passed in sts_client = boto3.client( "sts", @@ -752,6 +786,7 @@ def completion( aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = optional_params.pop("aws_bedrock_client", None) @@ -766,6 +801,7 @@ def completion( aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, + aws_web_identity_token=aws_web_identity_token, extra_headers=extra_headers, timeout=timeout, ) @@ -1288,6 +1324,7 @@ def embedding( aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = init_bedrock_client( @@ -1295,6 +1332,7 @@ def embedding( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_web_identity_token=aws_web_identity_token, aws_role_name=aws_role_name, aws_session_name=aws_session_name, ) @@ -1377,6 +1415,7 @@ def image_generation( aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = init_bedrock_client( @@ -1384,6 +1423,7 @@ def image_generation( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_web_identity_token=aws_web_identity_token, aws_role_name=aws_role_name, aws_session_name=aws_session_name, timeout=timeout, diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 3f5c831d73..ef6774fd2f 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth(): # test_completion_bedrock_claude_sts_client_auth() +@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set") +def test_completion_bedrock_claude_sts_oidc_auth(): + print("\ncalling bedrock claude with oidc auth") + import os + + aws_web_identity_token = "oidc/circleci_v2/" + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] + + try: + litellm.set_verbose = True + + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_region_name=aws_region_name, + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + # Add any assertions here to check the response + print(response) + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + def test_bedrock_extra_headers(): try: