fix(base_invoke_transformation.py): support extra_headers on bedrock invoke route

Fixes https://github.com/BerriAI/litellm/issues/9106
This commit is contained in:
Krrish Dholakia 2025-03-10 16:13:11 -07:00
parent ea058ab4ea
commit 68bd05ac24
3 changed files with 53 additions and 12 deletions

View file

@ -129,7 +129,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
## CREDENTIALS ## ## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
extra_headers = optional_params.get("extra_headers", None)
aws_secret_access_key = optional_params.get("aws_secret_access_key", None) aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
aws_access_key_id = optional_params.get("aws_access_key_id", None) aws_access_key_id = optional_params.get("aws_access_key_id", None)
aws_session_token = optional_params.get("aws_session_token", None) aws_session_token = optional_params.get("aws_session_token", None)
@ -155,9 +154,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
) )
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
headers = {"Content-Type": "application/json"} if headers is not None:
if extra_headers is not None: headers = {"Content-Type": "application/json", **headers}
headers = {"Content-Type": "application/json", **extra_headers} else:
headers = {"Content-Type": "application/json"}
request = AWSRequest( request = AWSRequest(
method="POST", method="POST",
@ -166,12 +166,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
headers=headers, headers=headers,
) )
sigv4.add_auth(request) sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
return dict(request.headers) request_headers_dict = dict(request.headers)
if (
headers is not None and "Authorization" in headers
): # prevent sigv4 from overwriting the auth header
request_headers_dict["Authorization"] = headers["Authorization"]
return request_headers_dict
def transform_request( def transform_request(
self, self,
@ -443,7 +444,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
) -> dict: ) -> dict:
return {} return headers
def get_error_class( def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]

View file

@ -873,7 +873,9 @@ class BaseLLMHTTPHandler:
elif isinstance(audio_file, bytes): elif isinstance(audio_file, bytes):
# Assume it's already binary data # Assume it's already binary data
binary_data = audio_file binary_data = audio_file
elif isinstance(audio_file, io.BufferedReader) or isinstance(audio_file, io.BytesIO): elif isinstance(audio_file, io.BufferedReader) or isinstance(
audio_file, io.BytesIO
):
# Handle file-like objects # Handle file-like objects
binary_data = audio_file.read() binary_data = audio_file.read()

View file

@ -956,7 +956,7 @@ def test_bedrock_ptu():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bedrock_extra_headers(): async def test_bedrock_custom_api_base():
""" """
Check if a url with 'modelId' passed in, is created correctly Check if a url with 'modelId' passed in, is created correctly
@ -994,6 +994,44 @@ async def test_bedrock_extra_headers():
mock_client_post.assert_called_once() mock_client_post.assert_called_once()
@pytest.mark.parametrize(
"model",
[
"anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/invoke/anthropic.claude-3-sonnet-20240229-v1:0",
],
)
@pytest.mark.asyncio
async def test_bedrock_extra_headers(model):
"""
Relevant Issue: https://github.com/BerriAI/litellm/issues/9106
"""
client = AsyncHTTPHandler()
with patch.object(client, "post", new=AsyncMock()) as mock_client_post:
litellm.set_verbose = True
from openai.types.chat import ChatCompletion
try:
response = await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "What's AWS?"}],
client=client,
extra_headers={"test": "hello world", "Authorization": "my-test-key"},
)
except Exception as e:
print(f"error: {e}")
print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
assert "test" in mock_client_post.call_args.kwargs["headers"]
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
assert (
mock_client_post.call_args.kwargs["headers"]["Authorization"]
== "my-test-key"
)
mock_client_post.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bedrock_custom_prompt_template(): async def test_bedrock_custom_prompt_template():
""" """