From 2c0551f0a8de5980f53258a3cdd3956b48f0e9a2 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 23 Nov 2024 16:58:25 -0800
Subject: [PATCH] perf improvement - use tasks for async logging pass through
 responses

---
 .../pass_through_endpoints.py | 23 ++++++------
 .../streaming_handler.py      | 36 ++++++++++---------
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 0fd174440..5fa6c0457 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -528,16 +528,19 @@ async def pass_through_request(  # noqa: PLR0915
         response_body: Optional[dict] = get_response_body(response)
         passthrough_logging_payload["response_body"] = response_body
         end_time = datetime.now()
-        await pass_through_endpoint_logging.pass_through_async_success_handler(
-            httpx_response=response,
-            response_body=response_body,
-            url_route=str(url),
-            result="",
-            start_time=start_time,
-            end_time=end_time,
-            logging_obj=logging_obj,
-            cache_hit=False,
-            **kwargs,
+
+        asyncio.create_task(
+            pass_through_endpoint_logging.pass_through_async_success_handler(
+                httpx_response=response,
+                response_body=response_body,
+                url_route=str(url),
+                result="",
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=logging_obj,
+                cache_hit=False,
+                **kwargs,
+            )
         )
 
         return Response(
diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index dc6aae3af..15bc7127e 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -58,22 +58,24 @@ class PassThroughStreamingHandler:
             # After all chunks are processed, handle post-processing
             end_time = datetime.now()
 
-            await PassThroughStreamingHandler._route_streaming_logging_to_handler(
-                litellm_logging_obj=litellm_logging_obj,
-                passthrough_success_handler_obj=passthrough_success_handler_obj,
-                url_route=url_route,
-                request_body=request_body or {},
-                endpoint_type=endpoint_type,
-                start_time=start_time,
-                raw_bytes=raw_bytes,
-                end_time=end_time,
+            asyncio.create_task(
+                PassThroughStreamingHandler.handle_logging_collected_stream_response(
+                    litellm_logging_obj=litellm_logging_obj,
+                    passthrough_success_handler_obj=passthrough_success_handler_obj,
+                    url_route=url_route,
+                    request_body=request_body or {},
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    raw_bytes=raw_bytes,
+                    end_time=end_time,
+                )
             )
         except Exception as e:
             verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
             raise
 
     @staticmethod
-    async def _route_streaming_logging_to_handler(
+    async def handle_logging_collected_stream_response(
         litellm_logging_obj: LiteLLMLoggingObj,
         passthrough_success_handler_obj: PassThroughEndpointLogging,
         url_route: str,
@@ -108,9 +110,9 @@ class PassThroughStreamingHandler:
                 all_chunks=all_chunks,
                 end_time=end_time,
             )
-            standard_logging_response_object = anthropic_passthrough_logging_handler_result[
-                "result"
-            ]
+            standard_logging_response_object = (
+                anthropic_passthrough_logging_handler_result["result"]
+            )
             kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
         elif endpoint_type == EndpointType.VERTEX_AI:
             vertex_passthrough_logging_handler_result = (
@@ -125,9 +127,9 @@ class PassThroughStreamingHandler:
                     end_time=end_time,
                 )
             )
-            standard_logging_response_object = vertex_passthrough_logging_handler_result[
-                "result"
-            ]
+            standard_logging_response_object = (
+                vertex_passthrough_logging_handler_result["result"]
+            )
             kwargs = vertex_passthrough_logging_handler_result["kwargs"]
 
         if standard_logging_response_object is None:
@@ -168,4 +170,4 @@ class PassThroughStreamingHandler:
         # Split by newlines and filter out empty lines
        lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
 
-        return lines
\ No newline at end of file
+        return lines
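
Context for the pattern this patch adopts: instead of awaiting the success-logging coroutine before returning the upstream response, the proxy now schedules it with asyncio.create_task and returns immediately, so response latency no longer includes logging latency. Below is a minimal, self-contained sketch of that fire-and-forget pattern. It is illustrative only; log_success and handle_request are hypothetical stand-ins, not LiteLLM APIs.

# Illustrative sketch only (not part of the patch); log_success and
# handle_request are hypothetical stand-ins for the proxy's logging
# handler and request handler.
import asyncio
from datetime import datetime


async def log_success(payload: dict, start_time: datetime, end_time: datetime) -> None:
    # Stand-in for slow, non-critical logging work the caller should not wait on.
    await asyncio.sleep(0.5)
    print(f"logged {payload} ({(end_time - start_time).total_seconds():.3f}s request)")


async def handle_request(payload: dict) -> str:
    start_time = datetime.now()
    # ... proxy the request upstream ...
    end_time = datetime.now()

    # Before: "await log_success(...)" blocked the response on logging latency.
    # After: schedule logging as a background task and return immediately.
    asyncio.create_task(log_success(payload, start_time, end_time))
    return "response returned without waiting for logging"


async def main() -> None:
    print(await handle_request({"route": "/v1/messages"}))
    # Give the background task time to finish before the loop closes; a
    # long-running server keeps its event loop alive, so this sleep is only
    # needed in a short-lived script like this one.
    await asyncio.sleep(1)


if __name__ == "__main__":
    asyncio.run(main())

One design caveat of fire-and-forget tasks, noted in the asyncio documentation: the event loop holds only a weak reference to tasks created this way, so code that depends on them completing generally keeps its own reference (for example, in a set that is cleaned up via a done callback). Whether that matters for these logging tasks depends on the surrounding proxy code.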