From 68408c4d7741d915a19096c49bb5d7638a359b81 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 13:17:25 -0800 Subject: [PATCH 1/5] run pass through logging async --- .../pass_through_endpoints.py | 22 ++++++------ .../streaming_handler.py | 34 ++++++++++--------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 130be6303..7d0e3b29b 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -529,16 +529,18 @@ async def pass_through_request( # noqa: PLR0915 response_body: Optional[dict] = get_response_body(response) passthrough_logging_payload["response_body"] = response_body end_time = datetime.now() - await pass_through_endpoint_logging.pass_through_async_success_handler( - httpx_response=response, - response_body=response_body, - url_route=str(url), - result="", - start_time=start_time, - end_time=end_time, - logging_obj=logging_obj, - cache_hit=False, - **kwargs, + asyncio.create_task( + pass_through_endpoint_logging.pass_through_async_success_handler( + httpx_response=response, + response_body=response_body, + url_route=str(url), + result="", + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + cache_hit=False, + **kwargs, + ) ) return Response( diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index dc6aae3af..adfd49c78 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -58,15 +58,17 @@ class PassThroughStreamingHandler: # After all chunks are processed, handle post-processing end_time = datetime.now() - await PassThroughStreamingHandler._route_streaming_logging_to_handler( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body or {}, - endpoint_type=endpoint_type, - start_time=start_time, - raw_bytes=raw_bytes, - end_time=end_time, + asyncio.create_task( + PassThroughStreamingHandler._route_streaming_logging_to_handler( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + raw_bytes=raw_bytes, + end_time=end_time, + ) ) except Exception as e: verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") @@ -108,9 +110,9 @@ class PassThroughStreamingHandler: all_chunks=all_chunks, end_time=end_time, ) - standard_logging_response_object = anthropic_passthrough_logging_handler_result[ - "result" - ] + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) kwargs = anthropic_passthrough_logging_handler_result["kwargs"] elif endpoint_type == EndpointType.VERTEX_AI: vertex_passthrough_logging_handler_result = ( @@ -125,9 +127,9 @@ class PassThroughStreamingHandler: end_time=end_time, ) ) - standard_logging_response_object = vertex_passthrough_logging_handler_result[ - "result" - ] + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) kwargs = vertex_passthrough_logging_handler_result["kwargs"] if standard_logging_response_object is None: @@ -168,4 +170,4 @@ class PassThroughStreamingHandler: # Split by newlines and filter out empty lines lines = [line.strip() for line in combined_str.split("\n") if line.strip()] - return lines \ No newline at end of file + return lines From 068f1af120aa1ad037c39b5e4e018ff2c7fe46fe Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 16:13:06 -0800 Subject: [PATCH 2/5] fix use thread_pool_executor for pass through logging --- litellm/proxy/pass_through_endpoints/success_handler.py | 8 +++++--- litellm/proxy/proxy_config.yaml | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index c9c7707f0..b603510ff 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -18,6 +18,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu from litellm.proxy._types import PassThroughEndpointLoggingResultValues from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.utils import StandardPassThroughResponseObject +from litellm.utils import executor as thread_pool_executor from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( AnthropicPassthroughLoggingHandler, @@ -93,15 +94,16 @@ class PassThroughEndpointLogging: standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text ) - threading.Thread( - target=logging_obj.success_handler, + thread_pool_executor.submit( + logging_obj.success_handler, args=( standard_logging_response_object, start_time, end_time, cache_hit, ), - ).start() + ) + await logging_obj.async_success_handler( result=( json.dumps(result) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 13fb1bcbe..40cd86c5c 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -21,4 +21,5 @@ router_settings: redis_password: os.environ/REDIS_PASSWORD litellm_settings: - callbacks: ["prometheus"] \ No newline at end of file + callbacks: ["prometheus"] + success_callback: ["langfuse"] \ No newline at end of file From 904ece6757fe580bbd9bdce945cba02131af2e9c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 16:51:53 -0800 Subject: [PATCH 3/5] test_pass_through_request_logging_failure_with_stream --- .../test_pass_through_unit_tests.py | 101 +++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py index c55bdc7a8..287a44e69 100644 --- a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py +++ b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py @@ -3,11 +3,13 @@ import os import sys from datetime import datetime from unittest.mock import AsyncMock, Mock, patch, MagicMock +from typing import Optional sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import fastapi import httpx import pytest import litellm @@ -21,6 +23,9 @@ from litellm.proxy.pass_through_endpoints.streaming_handler import ( PassThroughStreamingHandler, ) +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, +) from fastapi import Request from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( @@ -33,9 +38,21 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin @pytest.fixture def mock_request(): # Create a mock request with headers + class QueryParams: + def __init__(self): + self._dict = {} + class MockRequest: - def __init__(self, headers=None): + def __init__( + self, headers=None, method="POST", request_body: Optional[dict] = None + ): self.headers = headers or {} + self.query_params = QueryParams() + self.method = method + self.request_body = request_body or {} + + async def body(self) -> bytes: + return bytes(json.dumps(self.request_body), "utf-8") return MockRequest @@ -163,3 +180,85 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict): metadata = result["litellm_params"]["metadata"] print("metadata", metadata) assert metadata["tags"] == ["tag1", "tag2"] + + +athropic_request_body = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}], + "litellm_metadata": {"tags": ["hi", "hello"]}, +} + + +@pytest.mark.asyncio +async def test_pass_through_request_logging_failure( + mock_request, mock_user_api_key_dict +): + """ + Test that pass_through_request still returns a response even if logging raises an Exception + """ + + # Mock the logging handler to raise an error + async def mock_logging_failure(*args, **kwargs): + raise Exception("Logging failed!") + + # Patch only the logging handler + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler", + new=mock_logging_failure, + ): + request = mock_request( + headers={}, method="POST", request_body=athropic_request_body + ) + response = await pass_through_request( + request=request, + target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages", + custom_headers={}, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert response was returned successfully despite logging failure + assert response.status_code == 200 + print("response", response) + print(vars(response)) + + +@pytest.mark.asyncio +async def test_pass_through_request_logging_failure_with_stream( + mock_request, mock_user_api_key_dict +): + """ + Test that pass_through_request still returns a response even if logging raises an Exception + """ + + # Mock the logging handler to raise an error + async def mock_logging_failure(*args, **kwargs): + raise Exception("Logging failed!") + + athropic_request_body["stream"] = True + # Patch only the logging handler + with patch( + "litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler", + new=mock_logging_failure, + ): + request = mock_request( + headers={}, method="POST", request_body=athropic_request_body + ) + response = await pass_through_request( + request=request, + target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages", + custom_headers={}, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert response was returned successfully despite logging failure + assert response.status_code == 200 + + print(vars(response)) + print(dir(response)) + body_iterator = response.body_iterator + async for chunk in body_iterator: + assert chunk + + print("response", response) + print(vars(response)) From f12141be4463f5fab6c41f160b3b45fd5deb691d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 21:44:12 -0800 Subject: [PATCH 4/5] fix anthropic pt logging test --- tests/pass_through_tests/test_anthropic_passthrough.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pass_through_tests/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough.py index b062a025a..6e7839282 100644 --- a/tests/pass_through_tests/test_anthropic_passthrough.py +++ b/tests/pass_through_tests/test_anthropic_passthrough.py @@ -141,7 +141,9 @@ async def test_anthropic_basic_completion_with_headers(): ), "Start time should be before end time" # Metadata assertions - assert log_entry["cache_hit"] == "False", "Cache should be off" + assert ( + str(log_entry["cache_hit"]).lower() != "true" + ), "Cache should be off" assert log_entry["request_tags"] == [ "test-tag-1", "test-tag-2", @@ -251,7 +253,9 @@ async def test_anthropic_streaming_with_headers(): ), "Start time should be before end time" # Metadata assertions - assert log_entry["cache_hit"] == "False", "Cache should be off" + assert ( + str(log_entry["cache_hit"]).lower() != "true" + ), "Cache should be off" assert log_entry["request_tags"] == [ "test-tag-stream-1", "test-tag-stream-2", From 6715905994ccd3420444dedb2a3f26f983da140d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 21:51:00 -0800 Subject: [PATCH 5/5] test_pass_through_request_logging_failure --- .../test_pass_through_unit_tests.py | 82 ++++++++++++++++--- 1 file changed, 70 insertions(+), 12 deletions(-) diff --git a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py index 287a44e69..c564c14d2 100644 --- a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py +++ b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py @@ -202,10 +202,29 @@ async def test_pass_through_request_logging_failure( async def mock_logging_failure(*args, **kwargs): raise Exception("Logging failed!") - # Patch only the logging handler + # Create a mock response + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + + # Add mock content + mock_response._content = b'{"mock": "response"}' + + async def mock_aread(): + return mock_response._content + + mock_response.aread = mock_aread + + # Patch both the logging handler and the httpx client with patch( "litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler", new=mock_logging_failure, + ), patch( + "httpx.AsyncClient.send", + return_value=mock_response, + ), patch( + "httpx.AsyncClient.request", + return_value=mock_response, ): request = mock_request( headers={}, method="POST", request_body=athropic_request_body @@ -219,8 +238,14 @@ async def test_pass_through_request_logging_failure( # Assert response was returned successfully despite logging failure assert response.status_code == 200 - print("response", response) - print(vars(response)) + + # Verify we got the mock response content + if hasattr(response, "body"): + content = response.body + else: + content = await response.aread() + + assert content == b'{"mock": "response"}' @pytest.mark.asyncio @@ -235,11 +260,38 @@ async def test_pass_through_request_logging_failure_with_stream( async def mock_logging_failure(*args, **kwargs): raise Exception("Logging failed!") - athropic_request_body["stream"] = True - # Patch only the logging handler + # Create a mock response + mock_response = AsyncMock() + mock_response.status_code = 200 + + # Add headers property to mock response + mock_response.headers = { + "content-type": "application/json", # Not streaming + } + + # Create mock chunks for streaming + mock_chunks = [b'{"chunk": 1}', b'{"chunk": 2}'] + mock_response.body_iterator = AsyncMock() + mock_response.body_iterator.__aiter__.return_value = mock_chunks + + # Add aread method to mock response + mock_response._content = b'{"mock": "response"}' + + async def mock_aread(): + return mock_response._content + + mock_response.aread = mock_aread + + # Patch both the logging handler and the httpx client with patch( "litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler", new=mock_logging_failure, + ), patch( + "httpx.AsyncClient.send", + return_value=mock_response, + ), patch( + "httpx.AsyncClient.request", + return_value=mock_response, ): request = mock_request( headers={}, method="POST", request_body=athropic_request_body @@ -254,11 +306,17 @@ async def test_pass_through_request_logging_failure_with_stream( # Assert response was returned successfully despite logging failure assert response.status_code == 200 - print(vars(response)) - print(dir(response)) - body_iterator = response.body_iterator - async for chunk in body_iterator: - assert chunk + # For non-streaming responses, we can access the content directly + if hasattr(response, "body"): + content = response.body + else: + # For streaming responses, we need to read the chunks + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + content = b"".join(chunks) - print("response", response) - print(vars(response)) + # Verify we got some response content + assert content is not None + if isinstance(content, bytes): + assert len(content) > 0