diff --git a/tests/local_testing/test_sagemaker.py b/tests/local_testing/test_sagemaker.py index 8438c3c6ba..ba1ab11596 100644 --- a/tests/local_testing/test_sagemaker.py +++ b/tests/local_testing/test_sagemaker.py @@ -265,7 +265,7 @@ async def test_acompletion_sagemaker_non_stream(): # Assert mock_post.assert_called_once() _, kwargs = mock_post.call_args - args_to_sagemaker = kwargs["json"] + args_to_sagemaker = json.loads(kwargs["data"]) print("Arguments passed to sagemaker=", args_to_sagemaker) assert args_to_sagemaker == expected_payload assert ( @@ -325,7 +325,7 @@ async def test_completion_sagemaker_non_stream(): # Assert mock_post.assert_called_once() _, kwargs = mock_post.call_args - args_to_sagemaker = kwargs["json"] + args_to_sagemaker = json.loads(kwargs["data"]) print("Arguments passed to sagemaker=", args_to_sagemaker) assert args_to_sagemaker == expected_payload assert ( @@ -386,7 +386,7 @@ async def test_completion_sagemaker_prompt_template_non_stream(): # Assert mock_post.assert_called_once() _, kwargs = mock_post.call_args - args_to_sagemaker = kwargs["json"] + args_to_sagemaker = json.loads(kwargs["data"]) print("Arguments passed to sagemaker=", args_to_sagemaker) assert args_to_sagemaker == expected_payload @@ -445,7 +445,7 @@ async def test_completion_sagemaker_non_stream_with_aws_params(): # Assert mock_post.assert_called_once() _, kwargs = mock_post.call_args - args_to_sagemaker = kwargs["json"] + args_to_sagemaker = json.loads(kwargs["data"]) print("Arguments passed to sagemaker=", args_to_sagemaker) assert args_to_sagemaker == expected_payload assert (