forked from phoenix/litellm-mirror
fix(bedrock.py): add support for sts based boto3 initialization
https://github.com/BerriAI/litellm/issues/1476
This commit is contained in:
parent
d63147f342
commit
8e9dc09955
5 changed files with 106 additions and 4 deletions
|
@ -21,6 +21,7 @@ from openai import (
|
|||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError,
|
||||
PermissionDeniedError,
|
||||
)
|
||||
import httpx
|
||||
|
||||
|
@ -82,6 +83,17 @@ class Timeout(APITimeoutError): # type: ignore
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class PermissionDeniedError(PermissionDeniedError): # type:ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
self.status_code = 403
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class RateLimitError(RateLimitError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
self.status_code = 429
|
||||
|
|
|
@ -288,6 +288,8 @@ def init_bedrock_client(
|
|||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
@ -300,6 +302,8 @@ def init_bedrock_client(
|
|||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
aws_session_name,
|
||||
aws_role_name,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
|
@ -312,7 +316,11 @@ def init_bedrock_client(
|
|||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
aws_session_name,
|
||||
aws_role_name,
|
||||
) = params_to_check
|
||||
|
||||
### SET REGION NAME
|
||||
if region_name:
|
||||
pass
|
||||
elif aws_region_name:
|
||||
|
@ -338,7 +346,28 @@ def init_bedrock_client(
|
|||
|
||||
import boto3
|
||||
|
||||
if aws_access_key_id != None:
|
||||
### CHECK STS ###
|
||||
if aws_role_name is not None and aws_session_name is not None:
|
||||
# use sts if role name passed in
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
)
|
||||
elif aws_access_key_id is not None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
|
||||
|
@ -419,6 +448,8 @@ def completion(
|
|||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
|
@ -433,6 +464,8 @@ def completion(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
)
|
||||
|
||||
model = model
|
||||
|
@ -738,6 +771,8 @@ def embedding(
|
|||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
|
@ -748,6 +783,8 @@ def embedding(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
)
|
||||
if type(input) == str:
|
||||
embeddings = [
|
||||
|
|
|
@ -2527,9 +2527,7 @@ def embedding(
|
|||
)
|
||||
## Map to OpenAI Exception
|
||||
raise exception_type(
|
||||
model=model,
|
||||
original_exception=e,
|
||||
custom_llm_provider="azure" if azure == True else None,
|
||||
model=model, original_exception=e, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -150,6 +150,52 @@ def test_completion_bedrock_claude_external_client_auth():
|
|||
# test_completion_bedrock_claude_external_client_auth()
|
||||
|
||||
|
||||
def test_completion_bedrock_claude_sts_client_auth():
|
||||
print("\ncalling bedrock claude external client auth")
|
||||
import os
|
||||
|
||||
aws_access_key_id = os.environ["AWS_TEMP_ACCESS_KEY_ID"]
|
||||
aws_secret_access_key = os.environ["AWS_TEMP_SECRET_ACCESS_KEY"]
|
||||
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||
aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||
|
||||
try:
|
||||
import boto3
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
response = completion(
|
||||
model="bedrock/anthropic.claude-instant-v1",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.1,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name="my-test-session",
|
||||
)
|
||||
|
||||
response = embedding(
|
||||
model="cohere.embed-multilingual-v3",
|
||||
input=["hello world"],
|
||||
aws_region_name="us-east-1",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name="my-test-session",
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
test_completion_bedrock_claude_sts_client_auth()
|
||||
|
||||
|
||||
def test_provisioned_throughput():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
|
|
@ -60,6 +60,7 @@ from .exceptions import (
|
|||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
OpenAIError,
|
||||
PermissionDeniedError,
|
||||
ContextWindowExceededError,
|
||||
ContentPolicyViolationError,
|
||||
Timeout,
|
||||
|
@ -5924,6 +5925,14 @@ def exception_type(
|
|||
llm_provider="bedrock",
|
||||
response=original_exception.response,
|
||||
)
|
||||
if "AccessDeniedException" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise PermissionDeniedError(
|
||||
message=f"BedrockException PermissionDeniedError - {error_str}",
|
||||
model=model,
|
||||
llm_provider="bedrock",
|
||||
response=original_exception.response,
|
||||
)
|
||||
if (
|
||||
"throttlingException" in error_str
|
||||
or "ThrottlingException" in error_str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue