diff --git a/litellm/llms/sagemaker/chat/handler.py b/litellm/llms/sagemaker/chat/handler.py
index 3a90a15093..c827a8a5f7 100644
--- a/litellm/llms/sagemaker/chat/handler.py
+++ b/litellm/llms/sagemaker/chat/handler.py
@@ -5,6 +5,7 @@ from typing import Callable, Optional, Union
 import httpx
 
 from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import ModelResponse, get_secret
 
 from ..common_utils import AWSEventStreamDecoder
@@ -125,6 +126,7 @@ class SagemakerChatHandler(BaseAWSLLM):
         logger_fn=None,
         acompletion: bool = False,
         headers: dict = {},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
 
         # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
@@ -173,4 +175,5 @@ class SagemakerChatHandler(BaseAWSLLM):
             custom_endpoint=True,
             custom_llm_provider="sagemaker_chat",
             streaming_decoder=custom_stream_decoder,  # type: ignore
+            client=client,
         )
diff --git a/litellm/main.py b/litellm/main.py
index 6cc1057bb4..1826f2df78 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2604,6 +2604,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
+                client=client,
             )
 
             ## RESPONSE OBJECT
diff --git a/tests/local_testing/test_sagemaker.py b/tests/local_testing/test_sagemaker.py
index ba1ab11596..9c7161e4ae 100644
--- a/tests/local_testing/test_sagemaker.py
+++ b/tests/local_testing/test_sagemaker.py
@@ -8,7 +8,7 @@ from dotenv import load_dotenv
 load_dotenv()
 import io
 import os
-
+import litellm
 from test_streaming import streaming_format_tests
 
 sys.path.insert(
@@ -96,26 +96,57 @@ async def test_completion_sagemaker_messages_api(sync_mode):
         litellm.set_verbose = True
         verbose_logger.setLevel(logging.DEBUG)
         print("testing sagemaker")
+        from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
         if sync_mode is True:
-            resp = litellm.completion(
-                model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
-                messages=[
-                    {"role": "user", "content": "hi"},
-                ],
-                temperature=0.2,
-                max_tokens=80,
-            )
-            print(resp)
+            client = HTTPHandler()
+            with patch.object(client, "post") as mock_post:
+                try:
+                    resp = litellm.completion(
+                        model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+                        messages=[
+                            {"role": "user", "content": "hi"},
+                        ],
+                        temperature=0.2,
+                        max_tokens=80,
+                        client=client,
+                    )
+                except Exception as e:
+                    print(e)
+                mock_post.assert_called_once()
+                json_data = json.loads(mock_post.call_args.kwargs["data"])
+                assert (
+                    json_data["model"]
+                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
+                )
+                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
+                assert json_data["temperature"] == 0.2
+                assert json_data["max_tokens"] == 80
+
         else:
-            resp = await litellm.acompletion(
-                model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
-                messages=[
-                    {"role": "user", "content": "hi"},
-                ],
-                temperature=0.2,
-                max_tokens=80,
-            )
-            print(resp)
+            client = AsyncHTTPHandler()
+            with patch.object(client, "post") as mock_post:
+                try:
+                    resp = await litellm.acompletion(
+                        model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+                        messages=[
+                            {"role": "user", "content": "hi"},
+                        ],
+                        temperature=0.2,
+                        max_tokens=80,
+                        client=client,
+                    )
+                except Exception as e:
+                    print(e)
+                mock_post.assert_called_once()
+                json_data = json.loads(mock_post.call_args.kwargs["data"])
+                assert (
+                    json_data["model"]
+                    == "huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245"
+                )
+                assert json_data["messages"] == [{"role": "user", "content": "hi"}]
+                assert json_data["temperature"] == 0.2
+                assert json_data["max_tokens"] == 80
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
@@ -125,7 +156,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )
@@ -185,7 +216,7 @@ async def test_completion_sagemaker_stream(sync_mode, model):
 @pytest.mark.parametrize(
     "model",
     [
-        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        # "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )
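For context, a minimal usage sketch of what this patch enables, mirroring the updated test: a caller constructs its own `HTTPHandler` and passes it as `client` to `litellm.completion`, which now forwards it through to the SageMaker chat handler. This is illustrative only, not part of the diff; the endpoint name is the one used in the test, and it assumes AWS credentials and a deployed SageMaker endpoint are available in the environment.

```python
# Illustrative sketch (assumes AWS credentials and a deployed SageMaker endpoint).
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# Reusable HTTP client; after this change it is forwarded to the sagemaker_chat handler
# instead of being silently dropped.
client = HTTPHandler()

resp = litellm.completion(
    model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=80,
    client=client,
)
print(resp)
```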