mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
add non-stream mock tests for sagemaker
This commit is contained in:
parent
e217eda303
commit
b58c2bef1c
1 changed files with 127 additions and 0 deletions
127
litellm/tests/test_sagemaker.py
Normal file
127
litellm/tests/test_sagemaker.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||
|
||||
# litellm.num_retries =3
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
user_message = "Write a short poem about the sky"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
||||
|
||||
def logger_fn(user_model_dict):
|
||||
print(f"user_model_dict: {user_model_dict}")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_callbacks():
|
||||
print("\npytest fixture - resetting callbacks")
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm.callbacks = []
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_completion_sagemaker():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
print("testing sagemaker")
|
||||
response = await litellm.acompletion(
|
||||
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||
messages=[
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
temperature=0.2,
|
||||
max_tokens=80,
|
||||
input_cost_per_second=0.000420,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
cost = completion_cost(completion_response=response)
|
||||
print("calculated cost", cost)
|
||||
assert (
|
||||
cost > 0.0 and cost < 1.0
|
||||
) # should never be > $1 for a single completion call
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_sagemaker_non_stream():
|
||||
mock_response = AsyncMock()
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
"generated_text": "This is a mock response from SageMaker.",
|
||||
"id": "cmpl-mockid",
|
||||
"object": "text_completion",
|
||||
"created": 1629800000,
|
||||
"model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||
"choices": [
|
||||
{
|
||||
"text": "This is a mock response from SageMaker.",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
|
||||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
|
||||
expected_payload = {
|
||||
"inputs": "hi",
|
||||
"parameters": {"temperature": 0.2, "max_new_tokens": 80},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||
return_value=mock_response,
|
||||
) as mock_post:
|
||||
# Act: Call the litellm.acompletion function
|
||||
response = await litellm.acompletion(
|
||||
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
|
||||
messages=[
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
temperature=0.2,
|
||||
max_tokens=80,
|
||||
input_cost_per_second=0.000420,
|
||||
)
|
||||
|
||||
# Print what was called on the mock
|
||||
print("call args=", mock_post.call_args)
|
||||
|
||||
# Assert
|
||||
mock_post.assert_called_once()
|
||||
_, kwargs = mock_post.call_args
|
||||
args_to_sagemaker = kwargs["json"]
|
||||
print("Arguments passed to sagemaker=", args_to_sagemaker)
|
||||
assert args_to_sagemaker == expected_payload
|
||||
assert (
|
||||
kwargs["url"]
|
||||
== "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue