test: mock sagemaker tests

Krrish Dholakia 2025-03-21 16:18:02 -07:00
parent 58f46d847c
commit 48e6a7036b
3 changed files with 56 additions and 21 deletions

View file

@@ -5,6 +5,7 @@ from typing import Callable, Optional, Union
import httpx
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import ModelResponse, get_secret
from ..common_utils import AWSEventStreamDecoder
@@ -125,6 +126,7 @@ class SagemakerChatHandler(BaseAWSLLM):
logger_fn=None,
acompletion: bool = False,
headers: dict = {},
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
@@ -173,4 +175,5 @@ class SagemakerChatHandler(BaseAWSLLM):
custom_endpoint=True,
custom_llm_provider="sagemaker_chat",
streaming_decoder=custom_stream_decoder, # type: ignore
client=client,
)
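
For reference, a minimal sketch of how a caller can inject the new client argument once it reaches litellm.completion (the endpoint name below is a placeholder, not one from this commit):

from litellm.llms.custom_httpx.http_handler import HTTPHandler
import litellm

# A single HTTPHandler reuses one connection pool and is easy to patch in tests.
client = HTTPHandler()
resp = litellm.completion(
    model="sagemaker_chat/my-sagemaker-endpoint",  # placeholder endpoint name
    messages=[{"role": "user", "content": "hi"}],
    client=client,  # forwarded down to SagemakerChatHandler.completion
)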

View file

@@ -2604,6 +2604,7 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
client=client,
)
## RESPONSE OBJECT
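
The async path accepts an injected client the same way; a minimal sketch matching the mocked async test added below (the endpoint name is again a placeholder):

import asyncio
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def main():
    # AsyncHTTPHandler plays the same role as HTTPHandler on the sync path.
    client = AsyncHTTPHandler()
    resp = await litellm.acompletion(
        model="sagemaker_chat/my-sagemaker-endpoint",  # placeholder endpoint name
        messages=[{"role": "user", "content": "hi"}],
        client=client,
    )
    print(resp)

asyncio.run(main())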

View file

@@ -8,7 +8,7 @@ from dotenv import load_dotenv
load_dotenv()
import io
import os
import litellm
from test_streaming import streaming_format_tests
sys.path.insert(
@@ -96,26 +96,57 @@ async def test_completion_sagemaker_messages_api(sync_mode):
litellm.set_verbose = True
verbose_logger.setLevel(logging.DEBUG)
print("testing sagemaker")
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
if sync_mode is True:
resp = litellm.completion(
model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
)
print(resp)
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
resp = litellm.completion(
model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
client=client,
)
except Exception as e:
print(e)
mock_post.assert_called_once()
json_data = json.loads(mock_post.call_args.kwargs["data"])
assert (
json_data["model"]
== "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
)
assert json_data["messages"] == [{"role": "user", "content": "hi"}]
assert json_data["temperature"] == 0.2
assert json_data["max_tokens"] == 80
else:
resp = await litellm.acompletion(
model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
)
print(resp)
client = AsyncHTTPHandler()
with patch.object(client, "post") as mock_post:
try:
resp = await litellm.acompletion(
model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
messages=[
{"role": "user", "content": "hi"},
],
temperature=0.2,
max_tokens=80,
client=client,
)
except Exception as e:
print(e)
mock_post.assert_called_once()
json_data = json.loads(mock_post.call_args.kwargs["data"])
assert (
json_data["model"]
== "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
)
assert json_data["messages"] == [{"role": "user", "content": "hi"}]
assert json_data["temperature"] == 0.2
assert json_data["max_tokens"] == 80
except Exception as e:
pytest.fail(f"Error occurred: {e}")
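
Because client.post is patched, no request ever reaches SageMaker: the completion call is expected to fail while handling the MagicMock response, the exception is caught and printed, and the test then asserts only on the outgoing request payload captured in mock_post.call_args.kwargs["data"].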
@@ -125,7 +156,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
@pytest.mark.parametrize(
"model",
[
"sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
# "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
"sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
],
)
@@ -185,7 +216,7 @@ async def test_completion_sagemaker_stream(sync_mode, model):
@pytest.mark.parametrize(
"model",
[
"sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
# "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
"sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
],
)