Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
test: mock sagemaker tests

parent 58f46d847c · commit 48e6a7036b
3 changed files with 56 additions and 21 deletions
@@ -5,6 +5,7 @@ from typing import Callable, Optional, Union
 import httpx
 
 from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import ModelResponse, get_secret
 
 from ..common_utils import AWSEventStreamDecoder
@@ -125,6 +126,7 @@ class SagemakerChatHandler(BaseAWSLLM):
         logger_fn=None,
         acompletion: bool = False,
         headers: dict = {},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
 
         # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
@@ -173,4 +175,5 @@ class SagemakerChatHandler(BaseAWSLLM):
             custom_endpoint=True,
             custom_llm_provider="sagemaker_chat",
             streaming_decoder=custom_stream_decoder,  # type: ignore
+            client=client,
         )
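The handler hunks above add a standard dependency-injection hook: the caller may hand in a preconfigured sync or async HTTP client, and the handler falls back to its own default transport when none is given. A minimal sketch of that pattern, with stand-in names (StubClient and ChatHandler are illustrative, not litellm's actual classes):

from typing import Optional

import httpx


class StubClient:
    """Illustrative stand-in for litellm's HTTPHandler: a thin httpx wrapper."""

    def __init__(self, timeout: float = 600.0) -> None:
        self._client = httpx.Client(timeout=timeout)

    def post(self, url: str, data: str, headers: dict) -> httpx.Response:
        return self._client.post(url, content=data, headers=headers)


class ChatHandler:
    """Sketch: accept an optional injected client, else build a default one."""

    def completion(self, url: str, payload: str, client: Optional[StubClient] = None):
        # An injected client lets tests patch `post`, and lets callers reuse
        # one connection pool across many requests.
        client = client or StubClient()
        return client.post(url, data=payload, headers={"Content-Type": "application/json"})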
@@ -2604,6 +2604,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
+                client=client,
             )
 
             ## RESPONSE OBJECT
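In the completion() entry point the new keyword is simply threaded through to the SageMaker chat handler. After this change a caller can reuse one client across calls, roughly as follows (the model name is the test endpoint from this commit, and the call assumes AWS credentials are configured):

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# Reusing one client shares its connection pool across requests.
client = HTTPHandler()
resp = litellm.completion(
    model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=80,
    client=client,
)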
@@ -8,7 +8,7 @@ from dotenv import load_dotenv
 load_dotenv()
 import io
 import os
-
+import litellm
 from test_streaming import streaming_format_tests
 
 sys.path.insert(
@@ -96,7 +96,12 @@ async def test_completion_sagemaker_messages_api(sync_mode):
         litellm.set_verbose = True
         verbose_logger.setLevel(logging.DEBUG)
         print("testing sagemaker")
+        from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
         if sync_mode is True:
+            client = HTTPHandler()
+            with patch.object(client, "post") as mock_post:
+                try:
                     resp = litellm.completion(
                         model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                         messages=[
@@ -104,9 +109,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):
                         ],
                         temperature=0.2,
                         max_tokens=80,
+                        client=client,
                     )
                     print(resp)
+                except Exception as e:
+                    print(e)
+                mock_post.assert_called_once()
+                json_data = json.loads(mock_post.call_args.kwargs["data"])
+                assert (
+                    json_data["model"]
+                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
+                )
+                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
+                assert json_data["temperature"] == 0.2
+                assert json_data["max_tokens"] == 80
         else:
+            client = AsyncHTTPHandler()
+            with patch.object(client, "post") as mock_post:
+                try:
                     resp = await litellm.acompletion(
                         model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
                         messages=[
@@ -114,8 +134,19 @@ async def test_completion_sagemaker_messages_api(sync_mode):
                         ],
                         temperature=0.2,
                         max_tokens=80,
+                        client=client,
                     )
                     print(resp)
-
+                except Exception as e:
+                    print(e)
+                mock_post.assert_called_once()
+                json_data = json.loads(mock_post.call_args.kwargs["data"])
+                assert (
+                    json_data["model"]
+                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
+                )
+                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
+                assert json_data["temperature"] == 0.2
+                assert json_data["max_tokens"] == 80
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
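The test pattern above is worth isolating: by patching post on the injected client, the test needs no live SageMaker endpoint and can assert directly on the outbound request body. A stripped-down version of the same idea (the endpoint name here is hypothetical):

import json
from unittest.mock import patch

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
with patch.object(client, "post") as mock_post:
    try:
        litellm.completion(
            model="sagemaker_chat/my-endpoint",  # hypothetical endpoint name
            messages=[{"role": "user", "content": "hi"}],
            client=client,
        )
    except Exception as e:
        # The mocked post returns a MagicMock, so response parsing may raise;
        # only the outbound request matters for this assertion.
        print(e)

    # Verify what would have been sent over the wire.
    mock_post.assert_called_once()
    json_data = json.loads(mock_post.call_args.kwargs["data"])
    assert json_data["messages"] == [{"role": "user", "content": "hi"}]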
@@ -125,7 +156,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )
@@ -185,7 +216,7 @@ async def test_completion_sagemaker_stream(sync_mode, model):
 @pytest.mark.parametrize(
     "model",
     [
-        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )