diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index 86b1117ab..cad95b39b 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -779,3 +779,24 @@ class ModelResponseIterator:
                 raise StopAsyncIteration
             except ValueError as e:
                 raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
+        str_line = chunk
+        if isinstance(chunk, bytes):  # Handle binary data
+            str_line = chunk.decode("utf-8")  # Convert bytes to string
+        index = str_line.find("data:")
+        if index != -1:
+            str_line = str_line[index:]
+
+        if str_line.startswith("data:"):
+            data_json = json.loads(str_line[5:])
+            return self.chunk_parser(chunk=data_json)
+        else:
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
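(Not part of the diff.) A rough usage sketch of the new helper, assuming a ModelResponseIterator constructed the way the streaming handler below constructs it; the SSE payload is an illustrative Anthropic content_block_delta event, not captured output.

    from litellm.llms.anthropic.chat.handler import ModelResponseIterator

    iterator = ModelResponseIterator(
        streaming_response=None,  # not used by this helper
        sync_stream=False,
        json_mode=False,
    )

    # A "data:" line is handed to the existing chunk_parser ...
    data_line = 'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hi"}}'
    chunk = iterator.convert_str_chunk_to_generic_chunk(data_line)
    # expected: chunk["text"] == "Hi" and chunk["is_finished"] is False

    # ... while any other SSE line (e.g. an "event:" line) yields an empty chunk
    empty = iterator.convert_str_chunk_to_generic_chunk("event: content_block_delta")
    # expected: empty["text"] == "" and empty["finish_reason"] == ""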
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 6c9a93849..7467266b8 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -4,7 +4,7 @@ import json
 import traceback
 from base64 import b64encode
 from datetime import datetime
-from typing import AsyncIterable, List, Optional
+from typing import AsyncIterable, List, Optional, Union

 import httpx
 from fastapi import (
@@ -310,13 +310,15 @@ def get_endpoint_type(url: str) -> EndpointType:

 async def stream_response(
     response: httpx.Response,
+    request_body: Optional[dict],
     logging_obj: LiteLLMLoggingObj,
     endpoint_type: EndpointType,
     start_time: datetime,
     url: str,
-) -> AsyncIterable[bytes]:
+) -> AsyncIterable[Union[str, bytes]]:
     async for chunk in chunk_processor(
-        response.aiter_bytes(),
+        response=response,
+        request_body=request_body,
         litellm_logging_obj=logging_obj,
         endpoint_type=endpoint_type,
         start_time=start_time,
@@ -468,6 +470,7 @@ async def pass_through_request(  # noqa: PLR0915
             return StreamingResponse(
                 stream_response(
                     response=response,
+                    request_body=_parsed_body,
                     logging_obj=logging_obj,
                     endpoint_type=endpoint_type,
                     start_time=start_time,
@@ -506,6 +509,7 @@ async def pass_through_request(  # noqa: PLR0915
             return StreamingResponse(
                 stream_response(
                     response=response,
+                    request_body=_parsed_body,
                     logging_obj=logging_obj,
                     endpoint_type=endpoint_type,
                     start_time=start_time,
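(Not part of the diff.) A sketch of the call shape this change expects from the proxy side: the full httpx.Response and the parsed request body are handed to chunk_processor instead of response.aiter_bytes(); the wrapper function and its parameter names are illustrative.

    from litellm.proxy.pass_through_endpoints.streaming_handler import chunk_processor


    async def forward_stream(
        response, parsed_body, logging_obj, endpoint_type, start_time,
        success_handler, url_route,
    ):
        # chunk_processor now needs the response object (for aiter_lines/aiter_bytes)
        # and the request body (to recover the model name when logging).
        async for chunk in chunk_processor(
            response=response,
            request_body=parsed_body,
            litellm_logging_obj=logging_obj,
            endpoint_type=endpoint_type,
            start_time=start_time,
            passthrough_success_handler_obj=success_handler,
            url_route=url_route,
        ):
            yield chunk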
diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index b7faa21e4..d3c1edb7b 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -4,13 +4,22 @@ from datetime import datetime
 from enum import Enum
 from typing import AsyncIterable, Dict, List, Optional, Union

+import httpx
+
 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicIterator,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     ModelResponseIterator as VertexAIIterator,
 )
 from litellm.types.utils import GenericStreamingChunk

+from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
+    AnthropicPassthroughLoggingHandler,
+)
 from .success_handler import PassThroughEndpointLogging
 from .types import EndpointType

@@ -36,19 +45,49 @@ def get_iterator_class_from_endpoint_type(


 async def chunk_processor(
-    aiter_bytes: AsyncIterable[bytes],
+    response: httpx.Response,
+    request_body: Optional[dict],
     litellm_logging_obj: LiteLLMLoggingObj,
     endpoint_type: EndpointType,
     start_time: datetime,
     passthrough_success_handler_obj: PassThroughEndpointLogging,
     url_route: str,
-) -> AsyncIterable[bytes]:
-
+) -> AsyncIterable[Union[str, bytes]]:
+    request_body = request_body or {}
     iteratorClass = get_iterator_class_from_endpoint_type(endpoint_type)
+    aiter_bytes = response.aiter_bytes()
+    aiter_lines = response.aiter_lines()
+    all_chunks = []
     if iteratorClass is None:
         # Generic endpoint - litellm does not do any tracking / logging for this
-        async for chunk in aiter_bytes:
+        async for chunk in aiter_lines:
             yield chunk
+    elif endpoint_type == EndpointType.ANTHROPIC:
+        anthropic_iterator = AnthropicIterator(
+            sync_stream=False,
+            streaming_response=aiter_lines,
+            json_mode=False,
+        )
+        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
+            completion_stream=aiter_bytes,
+            model=None,
+            logging_obj=litellm_logging_obj,
+            custom_llm_provider="anthropic",
+        )
+        async for chunk in aiter_lines:
+            try:
+                generic_chunk = anthropic_iterator.convert_str_chunk_to_generic_chunk(
+                    chunk
+                )
+                litellm_chunk = custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
+                if litellm_chunk:
+                    all_chunks.append(litellm_chunk)
+            except Exception as e:
+                verbose_proxy_logger.error(
+                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
+                )
+            finally:
+                yield chunk
     else:
         # known streaming endpoint - litellm will do tracking / logging for this
         model_iterator = iteratorClass(
@@ -58,7 +97,6 @@
             completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
         )
         buffer = b""
-        all_chunks = []
         async for chunk in aiter_bytes:
             buffer += chunk
             try:
@@ -95,23 +133,75 @@
             except json.JSONDecodeError:
                 pass

+    await _handle_logging_collected_chunks(
+        litellm_logging_obj=litellm_logging_obj,
+        passthrough_success_handler_obj=passthrough_success_handler_obj,
+        url_route=url_route,
+        request_body=request_body,
+        endpoint_type=endpoint_type,
+        start_time=start_time,
+        end_time=datetime.now(),
+        all_chunks=all_chunks,
+    )
+
+
+async def _handle_logging_collected_chunks(
+    litellm_logging_obj: LiteLLMLoggingObj,
+    passthrough_success_handler_obj: PassThroughEndpointLogging,
+    url_route: str,
+    request_body: dict,
+    endpoint_type: EndpointType,
+    start_time: datetime,
+    all_chunks: List[Dict],
+    end_time: datetime,
+):
+    """
+    Build the complete response and handle the logging
+
+    This gets triggered once all the chunks are collected
+    """
+    try:
         complete_streaming_response: Optional[
             Union[litellm.ModelResponse, litellm.TextCompletionResponse]
         ] = litellm.stream_chunk_builder(chunks=all_chunks)
         if complete_streaming_response is None:
             complete_streaming_response = litellm.ModelResponse()
         end_time = datetime.now()
+        verbose_proxy_logger.debug(
+            "complete_streaming_response %s", complete_streaming_response
+        )
+        kwargs = {}

         if passthrough_success_handler_obj.is_vertex_route(url_route):
             _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
             complete_streaming_response.model = _model
             litellm_logging_obj.model = _model
             litellm_logging_obj.model_call_details["model"] = _model
+        elif endpoint_type == EndpointType.ANTHROPIC:
+            model = request_body.get("model", "")
+            kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
+                litellm_model_response=complete_streaming_response,
+                model=model,
+                kwargs=litellm_logging_obj.model_call_details,
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=litellm_logging_obj,
+            )
+            litellm_logging_obj.model = model
+            complete_streaming_response.model = model
+            litellm_logging_obj.model_call_details["model"] = model
+            # Remove start_time and end_time from kwargs since they'll be passed explicitly
+            kwargs.pop("start_time", None)
+            kwargs.pop("end_time", None)
+        litellm_logging_obj.model_call_details.update(kwargs)

         asyncio.create_task(
             litellm_logging_obj.async_success_handler(
                 result=complete_streaming_response,
                 start_time=start_time,
                 end_time=end_time,
+                **kwargs,
             )
         )
+    except Exception as e:
+        verbose_proxy_logger.error(f"Error handling logging collected chunks: {e}")
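(Not part of the diff.) A condensed sketch of the end-of-stream hand-off that _handle_logging_collected_chunks performs for an Anthropic pass-through, with the AnthropicPassthroughLoggingHandler kwargs step omitted; it assumes all_chunks holds the litellm chunks collected in chunk_processor above and that this runs inside the proxy's event loop.

    import asyncio
    from datetime import datetime

    import litellm


    async def log_collected_anthropic_stream(
        all_chunks, request_body, litellm_logging_obj, start_time
    ):
        # Fold the per-line chunks back into one OpenAI-format response object.
        complete_response = litellm.stream_chunk_builder(chunks=all_chunks)
        if complete_response is None:
            complete_response = litellm.ModelResponse()

        # The pass-through never mapped a litellm model name, so take it from the request body.
        model = request_body.get("model", "")
        complete_response.model = model
        litellm_logging_obj.model = model
        litellm_logging_obj.model_call_details["model"] = model

        # Success callbacks run in the background so the client stream is never blocked.
        asyncio.create_task(
            litellm_logging_obj.async_success_handler(
                result=complete_response,
                start_time=start_time,
                end_time=datetime.now(),
            )
        )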