mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(base_invoke_transformation.py): support extra_headers on bedrock invoke route
Fixes https://github.com/BerriAI/litellm/issues/9106
This commit is contained in:
parent
ea058ab4ea
commit
68bd05ac24
3 changed files with 53 additions and 12 deletions
|
@ -129,7 +129,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
|
|
||||||
## CREDENTIALS ##
|
## CREDENTIALS ##
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||||
extra_headers = optional_params.get("extra_headers", None)
|
|
||||||
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.get("aws_access_key_id", None)
|
aws_access_key_id = optional_params.get("aws_access_key_id", None)
|
||||||
aws_session_token = optional_params.get("aws_session_token", None)
|
aws_session_token = optional_params.get("aws_session_token", None)
|
||||||
|
@ -155,9 +154,10 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
if headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **headers}
|
||||||
|
else:
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
if extra_headers is not None:
|
|
||||||
headers = {"Content-Type": "application/json", **extra_headers}
|
|
||||||
|
|
||||||
request = AWSRequest(
|
request = AWSRequest(
|
||||||
method="POST",
|
method="POST",
|
||||||
|
@ -166,12 +166,13 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
sigv4.add_auth(request)
|
sigv4.add_auth(request)
|
||||||
if (
|
|
||||||
extra_headers is not None and "Authorization" in extra_headers
|
|
||||||
): # prevent sigv4 from overwriting the auth header
|
|
||||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
|
||||||
|
|
||||||
return dict(request.headers)
|
request_headers_dict = dict(request.headers)
|
||||||
|
if (
|
||||||
|
headers is not None and "Authorization" in headers
|
||||||
|
): # prevent sigv4 from overwriting the auth header
|
||||||
|
request_headers_dict["Authorization"] = headers["Authorization"]
|
||||||
|
return request_headers_dict
|
||||||
|
|
||||||
def transform_request(
|
def transform_request(
|
||||||
self,
|
self,
|
||||||
|
@ -443,7 +444,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
return {}
|
return headers
|
||||||
|
|
||||||
def get_error_class(
|
def get_error_class(
|
||||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
|
|
|
@ -873,7 +873,9 @@ class BaseLLMHTTPHandler:
|
||||||
elif isinstance(audio_file, bytes):
|
elif isinstance(audio_file, bytes):
|
||||||
# Assume it's already binary data
|
# Assume it's already binary data
|
||||||
binary_data = audio_file
|
binary_data = audio_file
|
||||||
elif isinstance(audio_file, io.BufferedReader) or isinstance(audio_file, io.BytesIO):
|
elif isinstance(audio_file, io.BufferedReader) or isinstance(
|
||||||
|
audio_file, io.BytesIO
|
||||||
|
):
|
||||||
# Handle file-like objects
|
# Handle file-like objects
|
||||||
binary_data = audio_file.read()
|
binary_data = audio_file.read()
|
||||||
|
|
||||||
|
|
|
@ -956,7 +956,7 @@ def test_bedrock_ptu():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bedrock_extra_headers():
|
async def test_bedrock_custom_api_base():
|
||||||
"""
|
"""
|
||||||
Check if a url with 'modelId' passed in, is created correctly
|
Check if a url with 'modelId' passed in, is created correctly
|
||||||
|
|
||||||
|
@ -994,6 +994,44 @@ async def test_bedrock_extra_headers():
|
||||||
mock_client_post.assert_called_once()
|
mock_client_post.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"bedrock/invoke/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bedrock_extra_headers(model):
|
||||||
|
"""
|
||||||
|
Relevant Issue: https://github.com/BerriAI/litellm/issues/9106
|
||||||
|
"""
|
||||||
|
client = AsyncHTTPHandler()
|
||||||
|
|
||||||
|
with patch.object(client, "post", new=AsyncMock()) as mock_client_post:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
from openai.types.chat import ChatCompletion
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=model,
|
||||||
|
messages=[{"role": "user", "content": "What's AWS?"}],
|
||||||
|
client=client,
|
||||||
|
extra_headers={"test": "hello world", "Authorization": "my-test-key"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"error: {e}")
|
||||||
|
|
||||||
|
print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
|
||||||
|
assert "test" in mock_client_post.call_args.kwargs["headers"]
|
||||||
|
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
|
||||||
|
assert (
|
||||||
|
mock_client_post.call_args.kwargs["headers"]["Authorization"]
|
||||||
|
== "my-test-key"
|
||||||
|
)
|
||||||
|
mock_client_post.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bedrock_custom_prompt_template():
|
async def test_bedrock_custom_prompt_template():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue