diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index a316eb7ea6..5414429d4c 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -129,7 +129,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): ## CREDENTIALS ## # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them - extra_headers = optional_params.get("extra_headers", None) aws_secret_access_key = optional_params.get("aws_secret_access_key", None) aws_access_key_id = optional_params.get("aws_access_key_id", None) aws_session_token = optional_params.get("aws_session_token", None) @@ -155,9 +154,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): ) sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) - headers = {"Content-Type": "application/json"} - if extra_headers is not None: - headers = {"Content-Type": "application/json", **extra_headers} + if headers is not None: + headers = {"Content-Type": "application/json", **headers} + else: + headers = {"Content-Type": "application/json"} request = AWSRequest( method="POST", @@ -166,12 +166,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): headers=headers, ) sigv4.add_auth(request) - if ( - extra_headers is not None and "Authorization" in extra_headers - ): # prevent sigv4 from overwriting the auth header - request.headers["Authorization"] = extra_headers["Authorization"] - return dict(request.headers) + request_headers_dict = dict(request.headers) + if ( + headers is not None and "Authorization" in headers + ): # prevent sigv4 from overwriting the auth header + request_headers_dict["Authorization"] = headers["Authorization"] + return request_headers_dict def transform_request( self, @@ -443,7 +444,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM): api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: - return {} + return headers def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 0459854c4e..9d67fd1a85 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -873,7 +873,9 @@ class BaseLLMHTTPHandler: elif isinstance(audio_file, bytes): # Assume it's already binary data binary_data = audio_file - elif isinstance(audio_file, io.BufferedReader) or isinstance(audio_file, io.BytesIO): + elif isinstance(audio_file, io.BufferedReader) or isinstance( + audio_file, io.BytesIO + ): # Handle file-like objects binary_data = audio_file.read() diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index cc8cc163d4..e2948789fc 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -956,7 +956,7 @@ def test_bedrock_ptu(): @pytest.mark.asyncio -async def test_bedrock_extra_headers(): +async def test_bedrock_custom_api_base(): """ Check if a url with 'modelId' passed in, is created correctly @@ -994,6 +994,44 @@ async def test_bedrock_extra_headers(): mock_client_post.assert_called_once() +@pytest.mark.parametrize( + "model", + [ + "anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/invoke/anthropic.claude-3-sonnet-20240229-v1:0", + ], +) +@pytest.mark.asyncio +async def test_bedrock_extra_headers(model): + """ + Relevant Issue: https://github.com/BerriAI/litellm/issues/9106 + """ + client = AsyncHTTPHandler() + + with patch.object(client, "post", new=AsyncMock()) as mock_client_post: + litellm.set_verbose = True + from openai.types.chat import ChatCompletion + + try: + response = await litellm.acompletion( + model=model, + messages=[{"role": "user", "content": "What's AWS?"}], + client=client, + extra_headers={"test": "hello world", "Authorization": "my-test-key"}, + ) + except Exception as e: + print(f"error: {e}") + + print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}") + assert "test" in mock_client_post.call_args.kwargs["headers"] + assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world" + assert ( + mock_client_post.call_args.kwargs["headers"]["Authorization"] + == "my-test-key" + ) + mock_client_post.assert_called_once() + + @pytest.mark.asyncio async def test_bedrock_custom_prompt_template(): """