From 24eb79da919b9e352467736c72c5cbfc6d8adb1b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 27 May 2024 19:00:55 -0700 Subject: [PATCH] test(test_bedrock_completion.py): refactor test bedrock headers test --- litellm/tests/test_bedrock_completion.py | 47 ++++++++++++++---------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 6ffc1a4c4..8b13765da 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -239,25 +239,6 @@ def test_completion_bedrock_claude_sts_oidc_auth(): pytest.fail(f"Error occurred: {e}") -def test_bedrock_extra_headers(): - try: - litellm.set_verbose = True - response: ModelResponse = completion( - model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0", - messages=messages, - max_tokens=10, - temperature=0.78, - extra_headers={"x-key": "x_key_value"}, - ) - # Add any assertions here to check the response - assert len(response.choices) > 0 - assert len(response.choices[0].message.content) > 0 - except RateLimitError: - pass - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - def test_bedrock_claude_3(): try: litellm.set_verbose = True @@ -536,3 +517,31 @@ def test_bedrock_ptu(): == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke" ) mock_client_post.assert_called_once() + + +def test_bedrock_extra_headers(): + """ + Check if a url with 'modelId' passed in, is created correctly + + Reference: https://github.com/BerriAI/litellm/issues/3805 + """ + client = HTTPHandler() + + with patch.object(client, "post", new=Mock()) as mock_client_post: + litellm.set_verbose = True + from openai.types.chat import ChatCompletion + + try: + response = litellm.completion( + model="anthropic.claude-3-sonnet-20240229-v1:0", + messages=[{"role": "user", "content": "What's AWS?"}], + client=client, + extra_headers={"test": "hello world"}, + ) + except Exception as e: + pass + + print(f"mock_client_post.call_args: {mock_client_post.call_args}") + assert "test" in mock_client_post.call_args.kwargs["headers"] + assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world" + mock_client_post.assert_called_once()