fix(bedrock_httpx.py): support 'Auth' header as extra_header

Fixes https://github.com/BerriAI/litellm/issues/5389#issuecomment-2313677977
This commit is contained in:
Krrish Dholakia 2024-08-27 16:08:54 -07:00
parent faf04985d6
commit 722ccba323
4 changed files with 25 additions and 2 deletions

View file

@ -894,6 +894,10 @@ class BedrockLLM(BaseAWSLLM):
method="POST", url=endpoint_url, data=data, headers=headers method="POST", url=endpoint_url, data=data, headers=headers
) )
sigv4.add_auth(request) sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare() prepped = request.prepare()
## LOGGING ## LOGGING
@ -1659,6 +1663,10 @@ class BedrockConverseLLM(BaseAWSLLM):
method="POST", url=endpoint_url, data=data, headers=headers method="POST", url=endpoint_url, data=data, headers=headers
) )
sigv4.add_auth(request) sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare() prepped = request.prepare()
## LOGGING ## LOGGING

View file

@ -196,6 +196,11 @@ class SagemakerLLM(BaseAWSLLM):
method="POST", url=api_base, data=encoded_data, headers=headers method="POST", url=api_base, data=encoded_data, headers=headers
) )
sigv4.add_auth(request) sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped_request = request.prepare() prepped_request = request.prepare()
return prepped_request return prepped_request

View file

@ -182,6 +182,11 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
method="POST", url=api_base, data=encoded_data, headers=headers method="POST", url=api_base, data=encoded_data, headers=headers
) )
sigv4.add_auth(request) sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped_request = request.prepare() prepped_request = request.prepare()
return prepped_request return prepped_request

View file

@ -945,7 +945,8 @@ async def test_bedrock_extra_headers():
""" """
Check if a url with 'modelId' passed in, is created correctly Check if a url with 'modelId' passed in, is created correctly
Reference: https://github.com/BerriAI/litellm/issues/3805 Reference: https://github.com/BerriAI/litellm/issues/3805, https://github.com/BerriAI/litellm/issues/5389#issuecomment-2313677977
""" """
client = AsyncHTTPHandler() client = AsyncHTTPHandler()
@ -958,7 +959,7 @@ async def test_bedrock_extra_headers():
model="anthropic.claude-3-sonnet-20240229-v1:0", model="anthropic.claude-3-sonnet-20240229-v1:0",
messages=[{"role": "user", "content": "What's AWS?"}], messages=[{"role": "user", "content": "What's AWS?"}],
client=client, client=client,
extra_headers={"test": "hello world"}, extra_headers={"test": "hello world", "Authorization": "my-test-key"},
) )
except Exception as e: except Exception as e:
pass pass
@ -966,6 +967,10 @@ async def test_bedrock_extra_headers():
print(f"mock_client_post.call_args: {mock_client_post.call_args}") print(f"mock_client_post.call_args: {mock_client_post.call_args}")
assert "test" in mock_client_post.call_args.kwargs["headers"] assert "test" in mock_client_post.call_args.kwargs["headers"]
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world" assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
assert (
mock_client_post.call_args.kwargs["headers"]["Authorization"]
== "my-test-key"
)
mock_client_post.assert_called_once() mock_client_post.assert_called_once()