test(test_bedrock_completion.py): refactor test bedrock headers test

This commit is contained in:
Krrish Dholakia 2024-05-27 19:00:55 -07:00
parent cac63fa0d9
commit 24eb79da91

View file

@ -239,25 +239,6 @@ def test_completion_bedrock_claude_sts_oidc_auth():
pytest.fail(f"Error occurred: {e}")
def test_bedrock_extra_headers():
try:
litellm.set_verbose = True
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=messages,
max_tokens=10,
temperature=0.78,
extra_headers={"x-key": "x_key_value"},
)
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
try:
litellm.set_verbose = True
@ -536,3 +517,31 @@ def test_bedrock_ptu():
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
)
mock_client_post.assert_called_once()
def test_bedrock_extra_headers():
"""
Check if a url with 'modelId' passed in, is created correctly
Reference: https://github.com/BerriAI/litellm/issues/3805
"""
client = HTTPHandler()
with patch.object(client, "post", new=Mock()) as mock_client_post:
litellm.set_verbose = True
from openai.types.chat import ChatCompletion
try:
response = litellm.completion(
model="anthropic.claude-3-sonnet-20240229-v1:0",
messages=[{"role": "user", "content": "What's AWS?"}],
client=client,
extra_headers={"test": "hello world"},
)
except Exception as e:
pass
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
assert "test" in mock_client_post.call_args.kwargs["headers"]
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
mock_client_post.assert_called_once()