From 8e9dc0995522d6f44f4f0030f2be92408189360a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 17 Jan 2024 12:08:39 -0800 Subject: [PATCH] fix(bedrock.py): add support for sts based boto3 initialization https://github.com/BerriAI/litellm/issues/1476 --- litellm/exceptions.py | 12 +++++++ litellm/llms/bedrock.py | 39 +++++++++++++++++++- litellm/main.py | 4 +-- litellm/tests/test_bedrock_completion.py | 46 ++++++++++++++++++++++++ litellm/utils.py | 9 +++++ 5 files changed, 106 insertions(+), 4 deletions(-) diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 4f9629e71..09b375811 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -21,6 +21,7 @@ from openai import ( APIConnectionError, APIResponseValidationError, UnprocessableEntityError, + PermissionDeniedError, ) import httpx @@ -82,6 +83,17 @@ class Timeout(APITimeoutError): # type: ignore ) # Call the base class constructor with the parameters it needs +class PermissionDeniedError(PermissionDeniedError): # type:ignore + def __init__(self, message, llm_provider, model, response: httpx.Response): + self.status_code = 403 + self.message = message + self.llm_provider = llm_provider + self.model = model + super().__init__( + self.message, response=response, body=None + ) # Call the base class constructor with the parameters it needs + + class RateLimitError(RateLimitError): # type: ignore def __init__(self, message, llm_provider, model, response: httpx.Response): self.status_code = 429 diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 85b820c9e..4c36137da 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -288,6 +288,8 @@ def init_bedrock_client( aws_secret_access_key: Optional[str] = None, aws_region_name: Optional[str] = None, aws_bedrock_runtime_endpoint: Optional[str] = None, + aws_session_name: Optional[str] = None, + aws_role_name: Optional[str] = None, ): # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) @@ -300,6 +302,8 @@ def init_bedrock_client( aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint, + aws_session_name, + aws_role_name, ] # Iterate over parameters and update if needed @@ -312,7 +316,11 @@ def init_bedrock_client( aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint, + aws_session_name, + aws_role_name, ) = params_to_check + + ### SET REGION NAME if region_name: pass elif aws_region_name: @@ -338,7 +346,28 @@ def init_bedrock_client( import boto3 - if aws_access_key_id != None: + ### CHECK STS ### + if aws_role_name is not None and aws_session_name is not None: + # use sts if role name passed in + sts_client = boto3.client( + "sts", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + + sts_response = sts_client.assume_role( + RoleArn=aws_role_name, RoleSessionName=aws_session_name + ) + + client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=region_name, + endpoint_url=endpoint_url, + ) + elif aws_access_key_id is not None: # uses auth params passed to completion # aws_access_key_id is not None, assume user is trying to auth using litellm.completion @@ -419,6 +448,8 @@ def completion( aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) @@ -433,6 +464,8 @@ def completion( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_role_name=aws_role_name, + aws_session_name=aws_session_name, ) model = model @@ -738,6 +771,8 @@ def embedding( aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) @@ -748,6 +783,8 @@ def embedding( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_role_name=aws_role_name, + aws_session_name=aws_session_name, ) if type(input) == str: embeddings = [ diff --git a/litellm/main.py b/litellm/main.py index 2d9b4dc32..632085816 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2527,9 +2527,7 @@ def embedding( ) ## Map to OpenAI Exception raise exception_type( - model=model, - original_exception=e, - custom_llm_provider="azure" if azure == True else None, + model=model, original_exception=e, custom_llm_provider=custom_llm_provider ) diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index d96fdfa5a..07b0cb288 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -150,6 +150,52 @@ def test_completion_bedrock_claude_external_client_auth(): # test_completion_bedrock_claude_external_client_auth() +def test_completion_bedrock_claude_sts_client_auth(): + print("\ncalling bedrock claude external client auth") + import os + + aws_access_key_id = os.environ["AWS_TEMP_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_TEMP_SECRET_ACCESS_KEY"] + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] + + try: + import boto3 + + litellm.set_verbose = True + + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + + response = embedding( + model="cohere.embed-multilingual-v3", + input=["hello world"], + aws_region_name="us-east-1", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + # Add any assertions here to check the response + print(response) + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +test_completion_bedrock_claude_sts_client_auth() + + def test_provisioned_throughput(): try: litellm.set_verbose = True diff --git a/litellm/utils.py b/litellm/utils.py index 3a01628a4..45bdca7fc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -60,6 +60,7 @@ from .exceptions import ( RateLimitError, ServiceUnavailableError, OpenAIError, + PermissionDeniedError, ContextWindowExceededError, ContentPolicyViolationError, Timeout, @@ -5924,6 +5925,14 @@ def exception_type( llm_provider="bedrock", response=original_exception.response, ) + if "AccessDeniedException" in error_str: + exception_mapping_worked = True + raise PermissionDeniedError( + message=f"BedrockException PermissionDeniedError - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) if ( "throttlingException" in error_str or "ThrottlingException" in error_str