forked from phoenix/litellm-mirror
fix(bedrock.py): add support for sts based boto3 initialization
https://github.com/BerriAI/litellm/issues/1476
This commit is contained in:
parent
d63147f342
commit
8e9dc09955
5 changed files with 106 additions and 4 deletions
|
@ -21,6 +21,7 @@ from openai import (
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIResponseValidationError,
|
APIResponseValidationError,
|
||||||
UnprocessableEntityError,
|
UnprocessableEntityError,
|
||||||
|
PermissionDeniedError,
|
||||||
)
|
)
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -82,6 +83,17 @@ class Timeout(APITimeoutError): # type: ignore
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionDeniedError(PermissionDeniedError): # type:ignore
|
||||||
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
|
self.status_code = 403
|
||||||
|
self.message = message
|
||||||
|
self.llm_provider = llm_provider
|
||||||
|
self.model = model
|
||||||
|
super().__init__(
|
||||||
|
self.message, response=response, body=None
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class RateLimitError(RateLimitError): # type: ignore
|
class RateLimitError(RateLimitError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 429
|
self.status_code = 429
|
||||||
|
|
|
@ -288,6 +288,8 @@ def init_bedrock_client(
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
aws_region_name: Optional[str] = None,
|
aws_region_name: Optional[str] = None,
|
||||||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||||
|
aws_session_name: Optional[str] = None,
|
||||||
|
aws_role_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
@ -300,6 +302,8 @@ def init_bedrock_client(
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
aws_region_name,
|
aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint,
|
||||||
|
aws_session_name,
|
||||||
|
aws_role_name,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Iterate over parameters and update if needed
|
# Iterate over parameters and update if needed
|
||||||
|
@ -312,7 +316,11 @@ def init_bedrock_client(
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
aws_region_name,
|
aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint,
|
||||||
|
aws_session_name,
|
||||||
|
aws_role_name,
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
|
### SET REGION NAME
|
||||||
if region_name:
|
if region_name:
|
||||||
pass
|
pass
|
||||||
elif aws_region_name:
|
elif aws_region_name:
|
||||||
|
@ -338,7 +346,28 @@ def init_bedrock_client(
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
if aws_access_key_id != None:
|
### CHECK STS ###
|
||||||
|
if aws_role_name is not None and aws_session_name is not None:
|
||||||
|
# use sts if role name passed in
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_response = sts_client.assume_role(
|
||||||
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||||
|
)
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="bedrock-runtime",
|
||||||
|
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||||
|
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||||
|
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||||
|
region_name=region_name,
|
||||||
|
endpoint_url=endpoint_url,
|
||||||
|
)
|
||||||
|
elif aws_access_key_id is not None:
|
||||||
# uses auth params passed to completion
|
# uses auth params passed to completion
|
||||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||||
|
|
||||||
|
@ -419,6 +448,8 @@ def completion(
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
@ -433,6 +464,8 @@ def completion(
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model
|
model = model
|
||||||
|
@ -738,6 +771,8 @@ def embedding(
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
@ -748,6 +783,8 @@ def embedding(
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
)
|
)
|
||||||
if type(input) == str:
|
if type(input) == str:
|
||||||
embeddings = [
|
embeddings = [
|
||||||
|
|
|
@ -2527,9 +2527,7 @@ def embedding(
|
||||||
)
|
)
|
||||||
## Map to OpenAI Exception
|
## Map to OpenAI Exception
|
||||||
raise exception_type(
|
raise exception_type(
|
||||||
model=model,
|
model=model, original_exception=e, custom_llm_provider=custom_llm_provider
|
||||||
original_exception=e,
|
|
||||||
custom_llm_provider="azure" if azure == True else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -150,6 +150,52 @@ def test_completion_bedrock_claude_external_client_auth():
|
||||||
# test_completion_bedrock_claude_external_client_auth()
|
# test_completion_bedrock_claude_external_client_auth()
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_bedrock_claude_sts_client_auth():
|
||||||
|
print("\ncalling bedrock claude external client auth")
|
||||||
|
import os
|
||||||
|
|
||||||
|
aws_access_key_id = os.environ["AWS_TEMP_ACCESS_KEY_ID"]
|
||||||
|
aws_secret_access_key = os.environ["AWS_TEMP_SECRET_ACCESS_KEY"]
|
||||||
|
aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||||
|
aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="bedrock/anthropic.claude-instant-v1",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.1,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name="my-test-session",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = embedding(
|
||||||
|
model="cohere.embed-multilingual-v3",
|
||||||
|
input=["hello world"],
|
||||||
|
aws_region_name="us-east-1",
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_session_name="my-test-session",
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
except RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
test_completion_bedrock_claude_sts_client_auth()
|
||||||
|
|
||||||
|
|
||||||
def test_provisioned_throughput():
|
def test_provisioned_throughput():
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
|
@ -60,6 +60,7 @@ from .exceptions import (
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
ServiceUnavailableError,
|
ServiceUnavailableError,
|
||||||
OpenAIError,
|
OpenAIError,
|
||||||
|
PermissionDeniedError,
|
||||||
ContextWindowExceededError,
|
ContextWindowExceededError,
|
||||||
ContentPolicyViolationError,
|
ContentPolicyViolationError,
|
||||||
Timeout,
|
Timeout,
|
||||||
|
@ -5924,6 +5925,14 @@ def exception_type(
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
)
|
)
|
||||||
|
if "AccessDeniedException" in error_str:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise PermissionDeniedError(
|
||||||
|
message=f"BedrockException PermissionDeniedError - {error_str}",
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock",
|
||||||
|
response=original_exception.response,
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
"throttlingException" in error_str
|
"throttlingException" in error_str
|
||||||
or "ThrottlingException" in error_str
|
or "ThrottlingException" in error_str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue