From 6715905994ccd3420444dedb2a3f26f983da140d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 21:51:00 -0800 Subject: [PATCH] test_pass_through_request_logging_failure --- .../test_pass_through_unit_tests.py | 82 ++++++++++++++++--- 1 file changed, 70 insertions(+), 12 deletions(-) diff --git a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py index 287a44e69..c564c14d2 100644 --- a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py +++ b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py @@ -202,10 +202,29 @@ async def test_pass_through_request_logging_failure( async def mock_logging_failure(*args, **kwargs): raise Exception("Logging failed!") - # Patch only the logging handler + # Create a mock response + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + + # Add mock content + mock_response._content = b'{"mock": "response"}' + + async def mock_aread(): + return mock_response._content + + mock_response.aread = mock_aread + + # Patch both the logging handler and the httpx client with patch( "litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler", new=mock_logging_failure, + ), patch( + "httpx.AsyncClient.send", + return_value=mock_response, + ), patch( + "httpx.AsyncClient.request", + return_value=mock_response, ): request = mock_request( headers={}, method="POST", request_body=athropic_request_body @@ -219,8 +238,14 @@ async def test_pass_through_request_logging_failure( # Assert response was returned successfully despite logging failure assert response.status_code == 200 - print("response", response) - print(vars(response)) + + # Verify we got the mock response content + if hasattr(response, "body"): + content = response.body + else: + content = await response.aread() + + assert content == b'{"mock": "response"}' @pytest.mark.asyncio @@ -235,11 +260,38 @@ async def test_pass_through_request_logging_failure_with_stream( async def mock_logging_failure(*args, **kwargs): raise Exception("Logging failed!") - athropic_request_body["stream"] = True - # Patch only the logging handler + # Create a mock response + mock_response = AsyncMock() + mock_response.status_code = 200 + + # Add headers property to mock response + mock_response.headers = { + "content-type": "application/json", # Not streaming + } + + # Create mock chunks for streaming + mock_chunks = [b'{"chunk": 1}', b'{"chunk": 2}'] + mock_response.body_iterator = AsyncMock() + mock_response.body_iterator.__aiter__.return_value = mock_chunks + + # Add aread method to mock response + mock_response._content = b'{"mock": "response"}' + + async def mock_aread(): + return mock_response._content + + mock_response.aread = mock_aread + + # Patch both the logging handler and the httpx client with patch( "litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler", new=mock_logging_failure, + ), patch( + "httpx.AsyncClient.send", + return_value=mock_response, + ), patch( + "httpx.AsyncClient.request", + return_value=mock_response, ): request = mock_request( headers={}, method="POST", request_body=athropic_request_body @@ -254,11 +306,17 @@ async def test_pass_through_request_logging_failure_with_stream( # Assert response was returned successfully despite logging failure assert response.status_code == 200 - print(vars(response)) - print(dir(response)) - body_iterator = response.body_iterator - async for chunk in body_iterator: - assert chunk + # For non-streaming responses, we can access the content directly + if hasattr(response, "body"): + content = response.body + else: + # For streaming responses, we need to read the chunks + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + content = b"".join(chunks) - print("response", response) - print(vars(response)) + # Verify we got some response content + assert content is not None + if isinstance(content, bytes): + assert len(content) > 0