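# Tests for litellm's Amazon SageMaker provider: live completion and streaming
# calls against a SageMaker JumpStart endpoint, plus mocked HTTP-handler tests
# that verify the request payload and invocation URL.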
import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt

# litellm.num_retries = 3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
import logging

from litellm._logging import verbose_logger


def logger_fn(user_model_dict):
    print(f"user_model_dict: {user_model_dict}")


@pytest.fixture(autouse=True)
def reset_callbacks():
    print("\npytest fixture - resetting callbacks")
    litellm.success_callback = []
    litellm._async_success_callback = []
    litellm.failure_callback = []
    litellm.callbacks = []


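# Live test: calls a SageMaker JumpStart endpoint (sync and async) and checks
# that completion_cost() returns a small, non-zero cost for the call.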
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_sagemaker(sync_mode):
    try:
        litellm.set_verbose = True
        print("testing sagemaker")
        if sync_mode is True:
            response = litellm.completion(
                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
                messages=[
                    {"role": "user", "content": "hi"},
                ],
                temperature=0.2,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )
        else:
            response = await litellm.acompletion(
                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
                messages=[
                    {"role": "user", "content": "hi"},
                ],
                temperature=0.2,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )
        # Add any assertions here to check the response
        print(response)
        cost = completion_cost(completion_response=response)
        print("calculated cost", cost)
        assert (
            cost > 0.0 and cost < 1.0
        )  # should never be > $1 for a single completion call
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


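# Live streaming test: requests stream=True (sync and async) and accumulates
# the delta content from each chunk into full_text.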
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [False, True])
async def test_completion_sagemaker_stream(sync_mode):
    try:
        litellm.set_verbose = False
        print("testing sagemaker")
        verbose_logger.setLevel(logging.DEBUG)
        full_text = ""
        if sync_mode is True:
            response = litellm.completion(
                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
                messages=[
                    {"role": "user", "content": "hi - what is ur name"},
                ],
                temperature=0.2,
                stream=True,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )

            for chunk in response:
                print(chunk)
                full_text += chunk.choices[0].delta.content or ""

            print("SYNC RESPONSE full text", full_text)
        else:
            response = await litellm.acompletion(
                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
                messages=[
                    {"role": "user", "content": "hi - what is ur name"},
                ],
                stream=True,
                temperature=0.2,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )

            print("streaming response")

            async for chunk in response:
                print(chunk)
                full_text += chunk.choices[0].delta.content or ""

            print("ASYNC RESPONSE full text", full_text)

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


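# Mocked async test: patches AsyncHTTPHandler.post so no real endpoint is hit,
# then asserts the JSON payload and invocation URL that acompletion() builds.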
@pytest.mark.asyncio
async def test_acompletion_sagemaker_non_stream():
    mock_response = AsyncMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val

    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: Call the litellm.acompletion function
        response = await litellm.acompletion(
            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = kwargs["json"]
        print("Arguments passed to sagemaker=", args_to_sagemaker)
        assert args_to_sagemaker == expected_payload
        assert (
            kwargs["url"]
            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
        )


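# Mocked sync test: same as above, but patches HTTPHandler.post and calls the
# synchronous litellm.completion() entry point.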
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream():
    mock_response = MagicMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val

    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: Call the litellm.completion function
        response = litellm.completion(
            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = kwargs["json"]
        print("Arguments passed to sagemaker=", args_to_sagemaker)
        assert args_to_sagemaker == expected_payload
        assert (
            kwargs["url"]
            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
        )