fix(bedrock.py): add support for sts based boto3 initialization

https://github.com/BerriAI/litellm/issues/1476
This commit is contained in:
Krrish Dholakia 2024-01-17 12:08:39 -08:00
parent d63147f342
commit 8e9dc09955
5 changed files with 106 additions and 4 deletions

View file

@ -21,6 +21,7 @@ from openai import (
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError,
PermissionDeniedError,
)
import httpx
@ -82,6 +83,17 @@ class Timeout(APITimeoutError): # type: ignore
) # Call the base class constructor with the parameters it needs
class PermissionDeniedError(PermissionDeniedError):  # type:ignore
    """Permission-denied (403) error tagged with the LLM provider and model.

    Extends the ``PermissionDeniedError`` imported from ``openai`` so the
    exception also carries ``llm_provider`` and ``model`` attributes.

    Args:
        message: Human-readable description of the failure.
        llm_provider: Name of the provider the failing call targeted.
        model: Name of the model the failing call targeted.
        response: The ``httpx.Response`` handed to the base constructor.
    """

    def __init__(self, message, llm_provider, model, response: httpx.Response):
        # Record the caller-supplied context on the instance first; these
        # assignments are independent of one another.
        self.model = model
        self.llm_provider = llm_provider
        self.message = message
        self.status_code = 403
        # Call the base class constructor with the parameters it needs
        super().__init__(message, response=response, body=None)
class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429

View file

@ -288,6 +288,8 @@ def init_bedrock_client(
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_bedrock_runtime_endpoint: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
@ -300,6 +302,8 @@ def init_bedrock_client(
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
aws_session_name,
aws_role_name,
]
# Iterate over parameters and update if needed
@ -312,7 +316,11 @@ def init_bedrock_client(
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
aws_session_name,
aws_role_name,
) = params_to_check
### SET REGION NAME
if region_name:
pass
elif aws_region_name:
@ -338,7 +346,28 @@ def init_bedrock_client(
import boto3
if aws_access_key_id != None:
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
)
elif aws_access_key_id is not None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@ -419,6 +448,8 @@ def completion(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
@ -433,6 +464,8 @@ def completion(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
model = model
@ -738,6 +771,8 @@ def embedding(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
@ -748,6 +783,8 @@ def embedding(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
if type(input) == str:
embeddings = [

View file

@ -2527,9 +2527,7 @@ def embedding(
)
## Map to OpenAI Exception
raise exception_type(
model=model,
original_exception=e,
custom_llm_provider="azure" if azure == True else None,
model=model, original_exception=e, custom_llm_provider=custom_llm_provider
)

View file

@ -150,6 +150,52 @@ def test_completion_bedrock_claude_external_client_auth():
# test_completion_bedrock_claude_external_client_auth()
def test_completion_bedrock_claude_sts_client_auth():
    """Exercise Bedrock completion and embedding calls that authenticate via STS.

    Temporary AWS credentials, the region, and the role to assume are read
    from the environment and forwarded, together with a session name, to
    ``completion`` and ``embedding``.

    A ``RateLimitError`` is tolerated; any other exception raised by the calls
    fails the test. A missing environment variable raises ``KeyError`` before
    the guarded section is entered.
    """
    print("\ncalling bedrock claude sts client auth")
    import os

    aws_access_key_id = os.environ["AWS_TEMP_ACCESS_KEY_ID"]
    aws_secret_access_key = os.environ["AWS_TEMP_SECRET_ACCESS_KEY"]
    aws_region_name = os.environ["AWS_REGION_NAME"]
    aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]

    try:
        litellm.set_verbose = True

        response = completion(
            model="bedrock/anthropic.claude-instant-v1",
            messages=messages,
            max_tokens=10,
            temperature=0.1,
            aws_region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_role_name=aws_role_name,
            aws_session_name="my-test-session",
        )

        response = embedding(
            model="cohere.embed-multilingual-v3",
            input=["hello world"],
            aws_region_name="us-east-1",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_role_name=aws_role_name,
            aws_session_name="my-test-session",
        )
        # Add any assertions here to check the response
        print(response)
    except RateLimitError:
        # Throttling by the provider is not treated as a test failure.
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# Left commented out so importing/collecting this module does not fire live
# AWS calls; uncomment only for a manual run.
# test_completion_bedrock_claude_sts_client_auth()
def test_provisioned_throughput():
try:
litellm.set_verbose = True

View file

@ -60,6 +60,7 @@ from .exceptions import (
RateLimitError,
ServiceUnavailableError,
OpenAIError,
PermissionDeniedError,
ContextWindowExceededError,
ContentPolicyViolationError,
Timeout,
@ -5924,6 +5925,14 @@ def exception_type(
llm_provider="bedrock",
response=original_exception.response,
)
if "AccessDeniedException" in error_str:
exception_mapping_worked = True
raise PermissionDeniedError(
message=f"BedrockException PermissionDeniedError - {error_str}",
model=model,
llm_provider="bedrock",
response=original_exception.response,
)
if (
"throttlingException" in error_str
or "ThrottlingException" in error_str