mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(base_invoke_transformation.py): support extra_headers on bedrock invoke route
Fixes https://github.com/BerriAI/litellm/issues/9106
This commit is contained in:
parent
ea058ab4ea
commit
68bd05ac24
3 changed files with 53 additions and 12 deletions
|
@ -129,7 +129,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
extra_headers = optional_params.get("extra_headers", None)
|
||||
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.get("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.get("aws_session_token", None)
|
||||
|
@ -155,9 +154,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
)
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
if headers is not None:
|
||||
headers = {"Content-Type": "application/json", **headers}
|
||||
else:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
|
||||
request = AWSRequest(
|
||||
method="POST",
|
||||
|
@ -166,12 +166,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
headers=headers,
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
|
||||
return dict(request.headers)
|
||||
request_headers_dict = dict(request.headers)
|
||||
if (
|
||||
headers is not None and "Authorization" in headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request_headers_dict["Authorization"] = headers["Authorization"]
|
||||
return request_headers_dict
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
|
@ -443,7 +444,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
return headers
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
|
|
|
@ -873,7 +873,9 @@ class BaseLLMHTTPHandler:
|
|||
elif isinstance(audio_file, bytes):
|
||||
# Assume it's already binary data
|
||||
binary_data = audio_file
|
||||
elif isinstance(audio_file, io.BufferedReader) or isinstance(audio_file, io.BytesIO):
|
||||
elif isinstance(audio_file, io.BufferedReader) or isinstance(
|
||||
audio_file, io.BytesIO
|
||||
):
|
||||
# Handle file-like objects
|
||||
binary_data = audio_file.read()
|
||||
|
||||
|
|
|
@ -956,7 +956,7 @@ def test_bedrock_ptu():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_extra_headers():
|
||||
async def test_bedrock_custom_api_base():
|
||||
"""
|
||||
Check if a url with 'modelId' passed in, is created correctly
|
||||
|
||||
|
@ -994,6 +994,44 @@ async def test_bedrock_extra_headers():
|
|||
mock_client_post.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"bedrock/invoke/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_extra_headers(model):
|
||||
"""
|
||||
Relevant Issue: https://github.com/BerriAI/litellm/issues/9106
|
||||
"""
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
with patch.object(client, "post", new=AsyncMock()) as mock_client_post:
|
||||
litellm.set_verbose = True
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "What's AWS?"}],
|
||||
client=client,
|
||||
extra_headers={"test": "hello world", "Authorization": "my-test-key"},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error: {e}")
|
||||
|
||||
print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
|
||||
assert "test" in mock_client_post.call_args.kwargs["headers"]
|
||||
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
|
||||
assert (
|
||||
mock_client_post.call_args.kwargs["headers"]["Authorization"]
|
||||
== "my-test-key"
|
||||
)
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_custom_prompt_template():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue