diff --git a/litellm/main.py b/litellm/main.py index c48c242ce..55ac01935 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2212,15 +2212,26 @@ def completion( custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict if "aws_bedrock_client" in optional_params: + verbose_logger.warning( + "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication." + ) # Extract credentials for legacy boto3 client and pass thru to httpx aws_bedrock_client = optional_params.pop("aws_bedrock_client") creds = aws_bedrock_client._get_credentials().get_frozen_credentials() + if creds.access_key: optional_params["aws_access_key_id"] = creds.access_key if creds.secret_key: optional_params["aws_secret_access_key"] = creds.secret_key if creds.token: optional_params["aws_session_token"] = creds.token + if ( + "aws_region_name" not in optional_params + or optional_params["aws_region_name"] is None + ): + optional_params["aws_region_name"] = ( + aws_bedrock_client.meta.region_name + ) if model in litellm.BEDROCK_CONVERSE_MODELS: response = bedrock_converse_chat_completion.completion( diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 6e39c30b3..fb4ba7556 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -856,3 +856,56 @@ async def test_bedrock_custom_prompt_template(): prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"] assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>" mock_client_post.assert_called_once() + + +def test_completion_bedrock_external_client_region(): + print("\ncalling bedrock claude external client auth") + import os + + aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_region_name = "us-east-1" + + os.environ.pop("AWS_ACCESS_KEY_ID", None) + os.environ.pop("AWS_SECRET_ACCESS_KEY", None) + + client = HTTPHandler() + + try: + import boto3 + + litellm.set_verbose = True + + bedrock = boto3.client( + service_name="bedrock-runtime", + region_name=aws_region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com", + ) + with patch.object(client, "post", new=Mock()) as mock_client_post: + try: + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_bedrock_client=bedrock, + client=client, + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pass + + print(f"mock_client_post.call_args: {mock_client_post.call_args}") + assert "us-east-1" in mock_client_post.call_args.kwargs["url"] + + mock_client_post.assert_called_once() + + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}")