# litellm-mirror/tests/local_testing/test_sagemaker.py
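"""
Tests for litellm's Amazon SageMaker completion support.

Covers live sync/async completion and streaming, bad-request handling, and
mocked checks of the JSON payload and invocation URL sent to the SageMaker
runtime endpoint.
"""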
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
import litellm
from test_streaming import streaming_format_tests
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries =3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
import logging
from litellm._logging import verbose_logger


def logger_fn(user_model_dict):
    print(f"user_model_dict: {user_model_dict}")


@pytest.fixture(autouse=True)
def reset_callbacks():
    print("\npytest fixture - resetting callbacks")
    litellm.success_callback = []
    litellm._async_success_callback = []
    litellm.failure_callback = []
    litellm.callbacks = []
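

# Live test: calls a real SageMaker JumpStart endpoint (sync and async) and
# checks that a non-zero, sub-$1 cost is computed for the completion.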
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_sagemaker(sync_mode):
    try:
        litellm.set_verbose = True
        verbose_logger.setLevel(logging.DEBUG)
        print("testing sagemaker")
        if sync_mode is True:
            response = litellm.completion(
                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
                messages=[
                    {"role": "user", "content": "hi"},
                ],
                temperature=0.2,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )
        else:
            response = await litellm.acompletion(
                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
                messages=[
                    {"role": "user", "content": "hi"},
                ],
                temperature=0.2,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )
        # Add any assertions here to check the response
        print(response)

        cost = completion_cost(completion_response=response)
        print("calculated cost", cost)
        assert (
            cost > 0.0 and cost < 1.0
        )  # should never be > $1 for a single completion call
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
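

# Mocked test: patches the HTTP client's post() and verifies the OpenAI-style
# "messages" payload (model, messages, temperature, max_tokens) sent to the
# sagemaker_chat (Messages API) endpoint.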
@pytest.mark.asyncio()
@pytest.mark.parametrize(
    "sync_mode",
    [True, False],
)
async def test_completion_sagemaker_messages_api(sync_mode):
    try:
        litellm.set_verbose = True
        verbose_logger.setLevel(logging.DEBUG)
        print("testing sagemaker")
        from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

        if sync_mode is True:
            client = HTTPHandler()
            with patch.object(client, "post") as mock_post:
                try:
                    resp = litellm.completion(
                        model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                        messages=[
                            {"role": "user", "content": "hi"},
                        ],
                        temperature=0.2,
                        max_tokens=80,
                        client=client,
                    )
                except Exception as e:
                    print(e)

                mock_post.assert_called_once()
                json_data = json.loads(mock_post.call_args.kwargs["data"])
                assert (
                    json_data["model"]
                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
                )
                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
                assert json_data["temperature"] == 0.2
                assert json_data["max_tokens"] == 80
        else:
            client = AsyncHTTPHandler()
            with patch.object(client, "post") as mock_post:
                try:
                    resp = await litellm.acompletion(
                        model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                        messages=[
                            {"role": "user", "content": "hi"},
                        ],
                        temperature=0.2,
                        max_tokens=80,
                        client=client,
                    )
                except Exception as e:
                    print(e)

                mock_post.assert_called_once()
                json_data = json.loads(mock_post.call_args.kwargs["data"])
                assert (
                    json_data["model"]
                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
                )
                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
                assert json_data["temperature"] == 0.2
                assert json_data["max_tokens"] == 80
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
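

# Live streaming test: streams a completion from the JumpStart endpoint,
# validates each chunk with streaming_format_tests, and accumulates the text.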
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.parametrize(
    "model",
    [
        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
    ],
)
# @pytest.mark.flaky(retries=3, delay=1)
async def test_completion_sagemaker_stream(sync_mode, model):
    try:
        litellm.set_verbose = False
        print("testing sagemaker")
        verbose_logger.setLevel(logging.DEBUG)
        full_text = ""
        if sync_mode is True:
            response = litellm.completion(
                model=model,
                messages=[
                    {"role": "user", "content": "hi - what is ur name"},
                ],
                temperature=0.2,
                stream=True,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )
            for idx, chunk in enumerate(response):
                print(chunk)
                streaming_format_tests(idx=idx, chunk=chunk)
                full_text += chunk.choices[0].delta.content or ""

            print("SYNC RESPONSE full text", full_text)
        else:
            response = await litellm.acompletion(
                model=model,
                messages=[
                    {"role": "user", "content": "hi - what is ur name"},
                ],
                stream=True,
                temperature=0.2,
                max_tokens=80,
                input_cost_per_second=0.000420,
            )
            print("streaming response")
            idx = 0
            async for chunk in response:
                print(chunk)
                streaming_format_tests(idx=idx, chunk=chunk)
                full_text += chunk.choices[0].delta.content or ""
                idx += 1

            print("ASYNC RESPONSE full text", full_text)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
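

# Live error-path test: requests an absurdly large max_tokens and expects
# litellm.BadRequestError for both sync and async streaming calls.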
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.parametrize(
    "model",
    [
        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
    ],
)
async def test_completion_sagemaker_streaming_bad_request(sync_mode, model):
    litellm.set_verbose = True
    print("testing sagemaker")
    if sync_mode is True:
        with pytest.raises(litellm.BadRequestError):
            response = litellm.completion(
                model=model,
                messages=[
                    {"role": "user", "content": "hi"},
                ],
                stream=True,
                max_tokens=8000000000000000,
            )
    else:
        with pytest.raises(litellm.BadRequestError):
            response = await litellm.acompletion(
                model=model,
                messages=[
                    {"role": "user", "content": "hi"},
                ],
                stream=True,
                max_tokens=8000000000000000,
            )
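

# Mocked async test: patches AsyncHTTPHandler.post and asserts the exact JSON
# payload and invocation URL sent to the SageMaker runtime endpoint.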
@pytest.mark.asyncio
async def test_acompletion_sagemaker_non_stream():
    mock_response = AsyncMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: Call the litellm.acompletion function
        response = await litellm.acompletion(
            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = json.loads(kwargs["data"])
        print("Arguments passed to sagemaker=", args_to_sagemaker)

        assert args_to_sagemaker == expected_payload
        assert (
            kwargs["url"]
            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
        )
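

# Mocked sync test: same payload and URL checks as above, but through the
# synchronous HTTPHandler and litellm.completion.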
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream():
    mock_response = MagicMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: Call the litellm.completion function
        response = litellm.completion(
            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = json.loads(kwargs["data"])
        print("Arguments passed to sagemaker=", args_to_sagemaker)

        assert args_to_sagemaker == expected_payload
        assert (
            kwargs["url"]
            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
        )
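

# Mocked test: passes hf_model_name so the DeepSeek Coder chat template is
# applied, and asserts the fully rendered prompt string sent as "inputs".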
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_sagemaker_prompt_template_non_stream():
    mock_response = MagicMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    expected_payload = {
        "inputs": "<begin▁of▁sentence>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n\n### Instruction:\nhi\n\n\n### Response:\n",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: Call the litellm.completion function
        response = litellm.completion(
            model="sagemaker/deepseek_coder_6.7_instruct",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            hf_model_name="deepseek-ai/deepseek-coder-6.7b-instruct",
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = json.loads(kwargs["data"])
        print("Arguments passed to sagemaker=", args_to_sagemaker)

        assert args_to_sagemaker == expected_payload
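

# Mocked test: supplies explicit AWS credentials and a region name and asserts
# the invocation URL targets that region (us-west-5 in the mocked call).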
@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream_with_aws_params():
    mock_response = MagicMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    expected_payload = {
        "inputs": "hi",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: Call the litellm.completion function
        response = litellm.completion(
            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            input_cost_per_second=0.000420,
            aws_access_key_id="gm",
            aws_secret_access_key="s",
            aws_region_name="us-west-5",
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = json.loads(kwargs["data"])
        print("Arguments passed to sagemaker=", args_to_sagemaker)

        assert args_to_sagemaker == expected_payload
        assert (
            kwargs["url"]
            == "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
        )