forked from phoenix/litellm-mirror
test(test_bedrock_completion.py): refactor test bedrock headers test
This commit is contained in:
parent
cac63fa0d9
commit
24eb79da91
1 changed files with 28 additions and 19 deletions
|
@ -239,25 +239,6 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_bedrock_extra_headers():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
response: ModelResponse = completion(
|
||||
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.78,
|
||||
extra_headers={"x-key": "x_key_value"},
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
assert len(response.choices) > 0
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_bedrock_claude_3():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
@ -536,3 +517,31 @@ def test_bedrock_ptu():
|
|||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
|
||||
)
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
|
||||
def test_bedrock_extra_headers():
|
||||
"""
|
||||
Check if a url with 'modelId' passed in, is created correctly
|
||||
|
||||
Reference: https://github.com/BerriAI/litellm/issues/3805
|
||||
"""
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", new=Mock()) as mock_client_post:
|
||||
litellm.set_verbose = True
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
messages=[{"role": "user", "content": "What's AWS?"}],
|
||||
client=client,
|
||||
extra_headers={"test": "hello world"},
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
|
||||
assert "test" in mock_client_post.call_args.kwargs["headers"]
|
||||
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
|
||||
mock_client_post.assert_called_once()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue