From 04c9284da43982e692975f47e9b1ad3c126ad464 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 22 Nov 2024 14:19:28 -0800 Subject: [PATCH] use PassThroughStreamingHandler --- .../streaming_handler.py | 164 ++++++++++-------- 1 file changed, 87 insertions(+), 77 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index a88ad34d3..522319aaa 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -27,97 +27,107 @@ from .success_handler import PassThroughEndpointLogging from .types import EndpointType -async def chunk_processor( - response: httpx.Response, - request_body: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - endpoint_type: EndpointType, - start_time: datetime, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, -): - """ - - Yields chunks from the response - - Collect non-empty chunks for post-processing (logging) - """ - try: - if endpoint_type == EndpointType.VERTEX_AI: +class PassThroughStreamingHandler: + + @staticmethod + async def chunk_processor( + response: httpx.Response, + request_body: Optional[dict], + litellm_logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + ): + """ + - Yields chunks from the response + - Collect non-empty chunks for post-processing (logging) + """ + try: + raw_bytes: List[bytes] = [] async for chunk in response.aiter_bytes(): + raw_bytes.append(chunk) yield chunk - else: - collected_chunks: List[str] = [] # List to store all chunks - async for chunk in response.aiter_lines(): - verbose_proxy_logger.debug(f"Processing chunk: {chunk}") - if not chunk: - continue - - # Handle SSE format - pass through the raw SSE format - if isinstance(chunk, bytes): - chunk = chunk.decode("utf-8") - - # Store the chunk for post-processing - if chunk.strip(): # Only store non-empty chunks - collected_chunks.append(chunk) - yield f"{chunk}\n" # After all chunks are processed, handle post-processing end_time = datetime.now() - await _route_streaming_logging_to_handler( + await PassThroughStreamingHandler._route_streaming_logging_to_handler( litellm_logging_obj=litellm_logging_obj, passthrough_success_handler_obj=passthrough_success_handler_obj, url_route=url_route, request_body=request_body or {}, endpoint_type=endpoint_type, start_time=start_time, - all_chunks=collected_chunks, + raw_bytes=raw_bytes, end_time=end_time, ) + except Exception as e: + verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") + raise - except Exception as e: - verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") - raise + @staticmethod + async def _route_streaming_logging_to_handler( + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + raw_bytes: List[bytes], + end_time: datetime, + ): + """ + Route the logging for the collected chunks to the appropriate handler - -async def _route_streaming_logging_to_handler( - litellm_logging_obj: LiteLLMLoggingObj, - passthrough_success_handler_obj: PassThroughEndpointLogging, - url_route: str, - request_body: dict, - endpoint_type: EndpointType, - start_time: datetime, - all_chunks: List[str], - end_time: datetime, -): - """ - Route the logging for the collected chunks to the appropriate handler - - Supported endpoint types: - - Anthropic - - Vertex AI - """ - if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, + Supported endpoint types: + - Anthropic + - Vertex AI + """ + all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( + raw_bytes ) - elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, - ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + if endpoint_type == EndpointType.ANTHROPIC: + await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.VERTEX_AI: + await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) + elif endpoint_type == EndpointType.GENERIC: + # No logging is supported for generic streaming endpoints + pass + + @staticmethod + def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: + """ + Converts a list of raw bytes into a list of string lines, similar to aiter_lines() + + Args: + raw_bytes: List of bytes chunks from aiter.bytes() + + Returns: + List of string lines, with each line being a complete data: {} chunk + """ + # Combine all bytes and decode to string + combined_str = b"".join(raw_bytes).decode("utf-8") + + # Split by newlines and filter out empty lines + lines = [line.strip() for line in combined_str.split("\n") if line.strip()] + + return lines