fix(main.py): get the region name from boto3 client if dynamic var not set

This commit is contained in:
Krrish Dholakia 2024-07-02 09:24:07 -07:00
parent 5aae2313f3
commit 79670ab82e
2 changed files with 64 additions and 0 deletions

View file

@ -2212,15 +2212,26 @@ def completion(
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "aws_bedrock_client" in optional_params:
verbose_logger.warning(
"'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
)
# Extract credentials for legacy boto3 client and pass thru to httpx
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
if creds.access_key:
optional_params["aws_access_key_id"] = creds.access_key
if creds.secret_key:
optional_params["aws_secret_access_key"] = creds.secret_key
if creds.token:
optional_params["aws_session_token"] = creds.token
if (
"aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None
):
optional_params["aws_region_name"] = (
aws_bedrock_client.meta.region_name
)
if model in litellm.BEDROCK_CONVERSE_MODELS:
response = bedrock_converse_chat_completion.completion(

View file

@ -856,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
mock_client_post.assert_called_once()
def test_completion_bedrock_external_client_region():
print("\ncalling bedrock claude external client auth")
import os
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_region_name = "us-east-1"
os.environ.pop("AWS_ACCESS_KEY_ID", None)
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
client = HTTPHandler()
try:
import boto3
litellm.set_verbose = True
bedrock = boto3.client(
service_name="bedrock-runtime",
region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
)
with patch.object(client, "post", new=Mock()) as mock_client_post:
try:
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_bedrock_client=bedrock,
client=client,
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pass
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
mock_client_post.assert_called_once()
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")