test: mock sagemaker tests

Krrish Dholakia 2025-03-21 16:18:02 -07:00
parent 58f46d847c
commit 48e6a7036b
3 changed files with 56 additions and 21 deletions

View file

@@ -5,6 +5,7 @@ from typing import Callable, Optional, Union
 import httpx
 from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import ModelResponse, get_secret
 from ..common_utils import AWSEventStreamDecoder
@@ -125,6 +126,7 @@ class SagemakerChatHandler(BaseAWSLLM):
         logger_fn=None,
         acompletion: bool = False,
         headers: dict = {},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
@@ -173,4 +175,5 @@ class SagemakerChatHandler(BaseAWSLLM):
             custom_endpoint=True,
             custom_llm_provider="sagemaker_chat",
             streaming_decoder=custom_stream_decoder,  # type: ignore
+            client=client,
         )
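
Note: the hunks above thread an optional, caller-supplied HTTP client through SagemakerChatHandler instead of always constructing one internally, which is what makes the request path mockable. A minimal sketch of that injection pattern, using a simplified stand-in for litellm's HTTPHandler (the real class wraps httpx with retries and other plumbing; the completion function here is illustrative, not litellm's):

from typing import Optional

import httpx


class HTTPHandler:
    # Simplified stand-in for litellm's HTTPHandler: a thin sync httpx wrapper.
    def __init__(self) -> None:
        self.client = httpx.Client()

    def post(self, url: str, data: str, headers: dict) -> httpx.Response:
        return self.client.post(url, content=data, headers=headers)


def completion(url: str, payload: str, client: Optional[HTTPHandler] = None) -> httpx.Response:
    # Construct a handler only when the caller injects nothing; a test can
    # pass a client whose .post is patched, so no network I/O ever runs.
    client = client or HTTPHandler()
    return client.post(url, data=payload, headers={"Content-Type": "application/json"})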

View file

@@ -2604,6 +2604,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
+                client=client,
             )
             ## RESPONSE OBJECT
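
With completion() now forwarding client into the SageMaker chat handler, callers control the transport. A hedged usage sketch (the endpoint name comes from the tests below; working AWS credentials in the environment are assumed):

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
resp = litellm.completion(
    model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=80,
    client=client,  # reused by litellm instead of an internally constructed handler
)
print(resp)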

View file

@@ -8,7 +8,7 @@ from dotenv import load_dotenv
 load_dotenv()
 import io
 import os
+import litellm
 from test_streaming import streaming_format_tests

 sys.path.insert(
@@ -96,7 +96,12 @@ async def test_completion_sagemaker_messages_api(sync_mode):
         litellm.set_verbose = True
         verbose_logger.setLevel(logging.DEBUG)
         print("testing sagemaker")
+        from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
         if sync_mode is True:
+            client = HTTPHandler()
+            with patch.object(client, "post") as mock_post:
+                try:
                     resp = litellm.completion(
                         model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                         messages=[
@@ -104,9 +109,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):
                         ],
                         temperature=0.2,
                         max_tokens=80,
+                        client=client,
                     )
-            print(resp)
+                except Exception as e:
+                    print(e)
+                mock_post.assert_called_once()
+                json_data = json.loads(mock_post.call_args.kwargs["data"])
+                assert (
+                    json_data["model"]
+                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
+                )
+                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
+                assert json_data["temperature"] == 0.2
+                assert json_data["max_tokens"] == 80
         else:
+            client = AsyncHTTPHandler()
+            with patch.object(client, "post") as mock_post:
+                try:
                     resp = await litellm.acompletion(
                         model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                         messages=[
@@ -114,8 +134,19 @@ async def test_completion_sagemaker_messages_api(sync_mode):
                         ],
                         temperature=0.2,
                         max_tokens=80,
+                        client=client,
                     )
-            print(resp)
+                except Exception as e:
+                    print(e)
+                mock_post.assert_called_once()
+                json_data = json.loads(mock_post.call_args.kwargs["data"])
+                assert (
+                    json_data["model"]
+                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
+                )
+                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
+                assert json_data["temperature"] == 0.2
+                assert json_data["max_tokens"] == 80
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
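
The test pattern above: inject a client, patch its post method so no SageMaker endpoint is hit, tolerate the parse failure that follows from the fake response, then assert on the serialized request body. A condensed, self-contained version of the sync branch (same endpoint and payload as the diff; AWS credentials are still assumed, since request signing happens before post is called):

import json
from unittest.mock import patch

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler


def test_sagemaker_request_body_is_mocked():
    client = HTTPHandler()
    with patch.object(client, "post") as mock_post:
        try:
            litellm.completion(
                model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                messages=[{"role": "user", "content": "hi"}],
                temperature=0.2,
                max_tokens=80,
                client=client,
            )
        except Exception:
            pass  # the MagicMock response is not parseable; only the outgoing request matters
    # The handler received exactly one POST; its body carries the prompt params.
    mock_post.assert_called_once()
    body = json.loads(mock_post.call_args.kwargs["data"])
    assert body["messages"] == [{"role": "user", "content": "hi"}]
    assert body["temperature"] == 0.2
    assert body["max_tokens"] == 80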
@@ -125,7 +156,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )
@@ -185,7 +216,7 @@ async def test_completion_sagemaker_stream(sync_mode, model):
 @pytest.mark.parametrize(
     "model",
     [
-        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )