From 8ce86e51594021599542e3d67f4e643b3aaaf8bf Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 21 Nov 2024 17:25:39 -0800
Subject: [PATCH] working anthropic streaming logging

---
 .../anthropic_passthrough_logging_handler.py | 100 +++++++-
 .../streaming_handler.py                     | 213 +++++-------------
 2 files changed, 155 insertions(+), 158 deletions(-)

diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
index 4f0ded375..e4c4fb6fc 100644
--- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
+++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py
@@ -1,6 +1,6 @@
 import json
 from datetime import datetime
-from typing import Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import httpx
 
@@ -10,8 +10,18 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.litellm_core_utils.litellm_logging import (
     get_standard_logging_object_payload,
 )
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicModelResponseIterator,
+)
 from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 
+if TYPE_CHECKING:
+    from ..success_handler import PassThroughEndpointLogging
+    from ..types import EndpointType
+else:
+    PassThroughEndpointLogging = Any
+    EndpointType = Any
+
 
 class AnthropicPassthroughLoggingHandler:
 
@@ -106,3 +116,91 @@ class AnthropicPassthroughLoggingHandler:
         )
         kwargs["standard_logging_object"] = standard_logging_object
         return kwargs
+
+    @staticmethod
+    async def _handle_logging_anthropic_collected_chunks(
+        litellm_logging_obj: LiteLLMLoggingObj,
+        passthrough_success_handler_obj: PassThroughEndpointLogging,
+        url_route: str,
+        request_body: dict,
+        endpoint_type: EndpointType,
+        start_time: datetime,
+        all_chunks: List[str],
+        end_time: datetime,
+    ):
+        """
+        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
+
+        - Builds complete response from chunks
+        - Creates standard logging object
+        - Logs in litellm callbacks
+        """
+        model = request_body.get("model", "")
+        complete_streaming_response = (
+            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
+                all_chunks=all_chunks,
+                litellm_logging_obj=litellm_logging_obj,
+                model=model,
+            )
+        )
+        if complete_streaming_response is None:
+            verbose_proxy_logger.error(
+                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
+            )
+            return
+        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
+            litellm_model_response=complete_streaming_response,
+            model=model,
+            kwargs={},
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=litellm_logging_obj,
+        )
+        await litellm_logging_obj.async_success_handler(
+            result=complete_streaming_response,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+            **kwargs,
+        )
+
+    @staticmethod
+    def _build_complete_streaming_response(
+        all_chunks: List[str],
+        litellm_logging_obj: LiteLLMLoggingObj,
+        model: str,
+    ) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
+        """
+        Builds complete response from raw Anthropic chunks
+
+        - Converts str chunks to generic chunks
+        - Converts generic chunks to litellm chunks (OpenAI format)
+        - Builds complete response from litellm chunks
+        """
+        anthropic_model_response_iterator = AnthropicModelResponseIterator(
+            streaming_response=None,
+            sync_stream=False,
+        )
+        litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
+            completion_stream=anthropic_model_response_iterator,
+            model=model,
+            logging_obj=litellm_logging_obj,
+            custom_llm_provider="anthropic",
+        )
+        all_openai_chunks = []
+        for _chunk_str in all_chunks:
+            try:
+                generic_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
+                    chunk=_chunk_str
+                )
+                litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
+                    chunk=generic_chunk
+                )
+                if litellm_chunk is not None:
+                    all_openai_chunks.append(litellm_chunk)
+            except (StopIteration, StopAsyncIteration) as e:
+                break
+        complete_streaming_response = litellm.stream_chunk_builder(
+            chunks=all_openai_chunks
+        )
+        return complete_streaming_response
diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index d3c1edb7b..9917d88c3 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -24,26 +24,6 @@ from .success_handler import PassThroughEndpointLogging
 from .types import EndpointType
 
 
-def get_litellm_chunk(
-    model_iterator: VertexAIIterator,
-    custom_stream_wrapper: litellm.utils.CustomStreamWrapper,
-    chunk_dict: Dict,
-) -> Optional[Dict]:
-
-    generic_chunk: GenericStreamingChunk = model_iterator.chunk_parser(chunk_dict)
-    if generic_chunk:
-        return custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
-    return None
-
-
-def get_iterator_class_from_endpoint_type(
-    endpoint_type: EndpointType,
-) -> Optional[type]:
-    if endpoint_type == EndpointType.VERTEX_AI:
-        return VertexAIIterator
-    return None
-
-
 async def chunk_processor(
     response: httpx.Response,
     request_body: Optional[dict],
@@ -52,156 +32,75 @@ async def chunk_processor(
     start_time: datetime,
     passthrough_success_handler_obj: PassThroughEndpointLogging,
     url_route: str,
-) -> AsyncIterable[Union[str, bytes]]:
-    request_body = request_body or {}
-    iteratorClass = get_iterator_class_from_endpoint_type(endpoint_type)
-    aiter_bytes = response.aiter_bytes()
-    aiter_lines = response.aiter_lines()
-    all_chunks = []
-    if iteratorClass is None:
-        # Generic endpoint - litellm does not do any tracking / logging for this
-        async for chunk in aiter_lines:
-            yield chunk
-    elif endpoint_type == EndpointType.ANTHROPIC:
-        anthropic_iterator = AnthropicIterator(
-            sync_stream=False,
-            streaming_response=aiter_lines,
-            json_mode=False,
+):
+    """
+    - Yields chunks from the response
+    - Collect non-empty chunks for post-processing (logging)
+    """
+    collected_chunks: List[str] = []  # List to store all chunks
+    try:
+        async for chunk in response.aiter_lines():
+            verbose_proxy_logger.debug(f"Processing chunk: {chunk}")
+            if not chunk:
+                continue
+
+            # Handle SSE format - pass through the raw SSE format
+            chunk = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
+
+            # Store the chunk for post-processing
+            if chunk.strip():  # Only store non-empty chunks
+                collected_chunks.append(chunk)
+            yield f"{chunk}\n"
+
+        # After all chunks are processed, handle post-processing
+        end_time = datetime.now()
+
+        await _route_streaming_logging_to_handler(
+            litellm_logging_obj=litellm_logging_obj,
+            passthrough_success_handler_obj=passthrough_success_handler_obj,
+            url_route=url_route,
+            request_body=request_body or {},
+            endpoint_type=endpoint_type,
+            start_time=start_time,
+            all_chunks=collected_chunks,
+            end_time=end_time,
         )
-        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
-            completion_stream=aiter_bytes,
-            model=None,
-            logging_obj=litellm_logging_obj,
-            custom_llm_provider="anthropic",
-        )
-        async for chunk in aiter_lines:
-            try:
-                generic_chunk = anthropic_iterator.convert_str_chunk_to_generic_chunk(
-                    chunk
-                )
-                litellm_chunk = custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
-                if litellm_chunk:
-                    all_chunks.append(litellm_chunk)
-            except Exception as e:
-                verbose_proxy_logger.error(
-                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
-                )
-            finally:
-                yield chunk
-    else:
-        # known streaming endpoint - litellm will do tracking / logging for this
-        model_iterator = iteratorClass(
-            sync_stream=False, streaming_response=aiter_bytes
-        )
-        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
-            completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
-        )
-        buffer = b""
-        async for chunk in aiter_bytes:
-            buffer += chunk
-            try:
-                _decoded_chunk = chunk.decode("utf-8")
-                _chunk_dict = json.loads(_decoded_chunk)
-                litellm_chunk = get_litellm_chunk(
-                    model_iterator, custom_stream_wrapper, _chunk_dict
-                )
-                if litellm_chunk:
-                    all_chunks.append(litellm_chunk)
-            except json.JSONDecodeError:
-                pass
-            finally:
-                yield chunk  # Yield the original bytes
-        # Process any remaining data in the buffer
-        if buffer:
-            try:
-                _chunk_dict = json.loads(buffer.decode("utf-8"))
-
-                if isinstance(_chunk_dict, list):
-                    for _chunk in _chunk_dict:
-                        litellm_chunk = get_litellm_chunk(
-                            model_iterator, custom_stream_wrapper, _chunk
-                        )
-                        if litellm_chunk:
-                            all_chunks.append(litellm_chunk)
-                elif isinstance(_chunk_dict, dict):
-                    litellm_chunk = get_litellm_chunk(
-                        model_iterator, custom_stream_wrapper, _chunk_dict
-                    )
-                    if litellm_chunk:
-                        all_chunks.append(litellm_chunk)
-            except json.JSONDecodeError:
-                pass
-
-    await _handle_logging_collected_chunks(
-        litellm_logging_obj=litellm_logging_obj,
-        passthrough_success_handler_obj=passthrough_success_handler_obj,
-        url_route=url_route,
-        request_body=request_body,
-        endpoint_type=endpoint_type,
-        start_time=start_time,
-        end_time=datetime.now(),
-        all_chunks=all_chunks,
-    )
+    except Exception as e:
+        verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
+        raise
 
 
-async def _handle_logging_collected_chunks(
+async def _route_streaming_logging_to_handler(
     litellm_logging_obj: LiteLLMLoggingObj,
     passthrough_success_handler_obj: PassThroughEndpointLogging,
     url_route: str,
     request_body: dict,
     endpoint_type: EndpointType,
     start_time: datetime,
-    all_chunks: List[Dict],
+    all_chunks: List[str],
    end_time: datetime,
 ):
     """
-    Build the complete response and handle the logging
+    Route the logging for the collected chunks to the appropriate handler
 
-    This gets triggered once all the chunks are collected
+    Supported endpoint types:
+    - Anthropic
+    - Vertex AI
     """
-    try:
-        complete_streaming_response: Optional[
-            Union[litellm.ModelResponse, litellm.TextCompletionResponse]
-        ] = litellm.stream_chunk_builder(chunks=all_chunks)
-        if complete_streaming_response is None:
-            complete_streaming_response = litellm.ModelResponse()
-        end_time = datetime.now()
-        verbose_proxy_logger.debug(
-            "complete_streaming_response %s", complete_streaming_response
+    if endpoint_type == EndpointType.ANTHROPIC:
+        await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
+            litellm_logging_obj=litellm_logging_obj,
+            passthrough_success_handler_obj=passthrough_success_handler_obj,
+            url_route=url_route,
+            request_body=request_body,
+            endpoint_type=endpoint_type,
+            start_time=start_time,
+            all_chunks=all_chunks,
+            end_time=end_time,
        )
-        kwargs = {}
-
-        if passthrough_success_handler_obj.is_vertex_route(url_route):
-            _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
-            complete_streaming_response.model = _model
-            litellm_logging_obj.model = _model
-            litellm_logging_obj.model_call_details["model"] = _model
-        elif endpoint_type == EndpointType.ANTHROPIC:
-            model = request_body.get("model", "")
-            kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
-                litellm_model_response=complete_streaming_response,
-                model=model,
-                kwargs=litellm_logging_obj.model_call_details,
-                start_time=start_time,
-                end_time=end_time,
-                logging_obj=litellm_logging_obj,
-            )
-            litellm_logging_obj.model = model
-            complete_streaming_response.model = model
-            litellm_logging_obj.model_call_details["model"] = model
-            # Remove start_time and end_time from kwargs since they'll be passed explicitly
-            kwargs.pop("start_time", None)
-            kwargs.pop("end_time", None)
-            litellm_logging_obj.model_call_details.update(kwargs)
-
-            asyncio.create_task(
-                litellm_logging_obj.async_success_handler(
-                    result=complete_streaming_response,
-                    start_time=start_time,
-                    end_time=end_time,
-                    **kwargs,
-                )
-            )
-    except Exception as e:
-        verbose_proxy_logger.error(f"Error handling logging collected chunks: {e}")
+    elif endpoint_type == EndpointType.VERTEX_AI:
+        pass
+    elif endpoint_type == EndpointType.GENERIC:
+        # No logging is supported for generic streaming endpoints
+        pass
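
Illustrative sketch (not part of the patch): the new streaming path collects the raw
Anthropic SSE lines while yielding them to the client, then rebuilds one response for
logging. The real code does this with AnthropicModelResponseIterator and
litellm.CustomStreamWrapper as shown above; the standalone example below only mimics
the reassembly step, and the function name and sample chunks are made up for
illustration.

    import json

    def rebuild_text_from_sse_lines(collected_chunks):
        # Join the text deltas from collected "data: {...}" lines into one string,
        # roughly what the chunk-to-response rebuild produces for logging.
        text_parts = []
        for line in collected_chunks:
            if not line.startswith("data:"):
                continue
            try:
                event = json.loads(line[len("data:"):].strip())
            except json.JSONDecodeError:
                continue
            if event.get("type") == "content_block_delta":
                text_parts.append(event.get("delta", {}).get("text", ""))
        return "".join(text_parts)

    # Made-up chunks in Anthropic's SSE format:
    sample_chunks = [
        'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}',
        'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " world"}}',
        'data: {"type": "message_stop"}',
    ]
    print(rebuild_text_from_sse_lines(sample_chunks))  # -> "Hello world"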