diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
new file mode 100644
index 0000000000..831ec5a2a8
--- /dev/null
+++ b/litellm/tests/test_sagemaker.py
@@ -0,0 +1,127 @@
+import json
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+
+# litellm.num_retries = 3
+litellm.cache = None
+litellm.success_callback = []
+user_message = "Write a short poem about the sky"
+messages = [{"content": user_message, "role": "user"}]
+
+
+def logger_fn(user_model_dict):
+    print(f"user_model_dict: {user_model_dict}")
+
+
+@pytest.fixture(autouse=True)
+def reset_callbacks():
+    print("\npytest fixture - resetting callbacks")
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    litellm.failure_callback = []
+    litellm.callbacks = []
+
+
+@pytest.mark.asyncio()
+async def test_completion_sagemaker():
+    try:
+        litellm.set_verbose = True
+        print("testing sagemaker")
+        response = await litellm.acompletion(
+            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            input_cost_per_second=0.000420,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        cost = completion_cost(completion_response=response)
+        print("calculated cost", cost)
+        assert (
+            cost > 0.0 and cost < 1.0
+        )  # should never be > $1 for a single completion call
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_acompletion_sagemaker_non_stream():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+
+    expected_payload = {
+        "inputs": "hi",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.acompletion function
+        response = await litellm.acompletion(
+            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            input_cost_per_second=0.000420,
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+        assert (
+            kwargs["url"]
+            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+        )