mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(bedrock_httpx.py): support 'Auth' header as extra_header
Fixes https://github.com/BerriAI/litellm/issues/5389#issuecomment-2313677977
This commit is contained in:
parent
1b2f73c220
commit
6431af0678
4 changed files with 25 additions and 2 deletions
|
@ -894,6 +894,10 @@ class BedrockLLM(BaseAWSLLM):
|
|||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
|
@ -1659,6 +1663,10 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
|
|
|
@ -196,6 +196,11 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
|
||||
prepped_request = request.prepare()
|
||||
|
||||
return prepped_request
|
||||
|
|
|
@ -182,6 +182,11 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
|
||||
prepped_request = request.prepare()
|
||||
|
||||
return prepped_request
|
||||
|
|
|
@ -945,7 +945,8 @@ async def test_bedrock_extra_headers():
|
|||
"""
|
||||
Check if a url with 'modelId' passed in, is created correctly
|
||||
|
||||
Reference: https://github.com/BerriAI/litellm/issues/3805
|
||||
Reference: https://github.com/BerriAI/litellm/issues/3805, https://github.com/BerriAI/litellm/issues/5389#issuecomment-2313677977
|
||||
|
||||
"""
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
|
@ -958,7 +959,7 @@ async def test_bedrock_extra_headers():
|
|||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
messages=[{"role": "user", "content": "What's AWS?"}],
|
||||
client=client,
|
||||
extra_headers={"test": "hello world"},
|
||||
extra_headers={"test": "hello world", "Authorization": "my-test-key"},
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
@ -966,6 +967,10 @@ async def test_bedrock_extra_headers():
|
|||
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
|
||||
assert "test" in mock_client_post.call_args.kwargs["headers"]
|
||||
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
|
||||
assert (
|
||||
mock_client_post.call_args.kwargs["headers"]["Authorization"]
|
||||
== "my-test-key"
|
||||
)
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue