working anthropic streaming logging

Ishaan Jaff 2024-11-21 17:25:39 -08:00
parent 0f7caa1cdb
commit 8ce86e5159
2 changed files with 155 additions and 158 deletions


@@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
@@ -10,8 +10,18 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import (
    get_standard_logging_object_payload,
)
from litellm.llms.anthropic.chat.handler import (
    ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig

if TYPE_CHECKING:
    from ..success_handler import PassThroughEndpointLogging
    from ..types import EndpointType
else:
    PassThroughEndpointLogging = Any
    EndpointType = Any


class AnthropicPassthroughLoggingHandler:
@@ -106,3 +116,91 @@
        )
        kwargs["standard_logging_object"] = standard_logging_object
        return kwargs
    @staticmethod
    async def _handle_logging_anthropic_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ):
        """
        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """
        model = request_body.get("model", "")
        complete_streaming_response = (
            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
            )
        )
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
            return
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )

        await litellm_logging_obj.async_success_handler(
            result=complete_streaming_response,
            start_time=start_time,
            end_time=end_time,
            cache_hit=False,
            **kwargs,
        )

    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
        """
        Builds complete response from raw Anthropic chunks

        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds complete response from litellm chunks
        """
        anthropic_model_response_iterator = AnthropicModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
            completion_stream=anthropic_model_response_iterator,
            model=model,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="anthropic",
        )
        all_openai_chunks = []
        for _chunk_str in all_chunks:
            try:
                generic_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
                    chunk=_chunk_str
                )
                litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
                    chunk=generic_chunk
                )
                if litellm_chunk is not None:
                    all_openai_chunks.append(litellm_chunk)
            except (StopIteration, StopAsyncIteration) as e:
                break
        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks
        )
        return complete_streaming_response
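
For reference, a minimal sketch (not part of this commit) of how the new helper could be exercised once a stream has been collected; `collected_chunks` and `litellm_logging_obj` are assumed to come from the proxy request handling in the second file below, and the model name is illustrative:

# Sketch only: rebuild one logged response from already-collected SSE lines.
complete_response = (
    AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
        all_chunks=collected_chunks,  # non-empty SSE lines gathered while streaming (assumed)
        litellm_logging_obj=litellm_logging_obj,  # LiteLLMLoggingObj for this request (assumed)
        model="claude-3-5-sonnet-20241022",  # illustrative model name
    )
)
if complete_response is None:
    # Mirrors the handler's own guard: nothing is logged when rebuilding fails.
    verbose_proxy_logger.error("Unable to rebuild response from collected chunks")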


@@ -24,26 +24,6 @@ from .success_handler import PassThroughEndpointLogging
from .types import EndpointType


def get_litellm_chunk(
    model_iterator: VertexAIIterator,
    custom_stream_wrapper: litellm.utils.CustomStreamWrapper,
    chunk_dict: Dict,
) -> Optional[Dict]:
    generic_chunk: GenericStreamingChunk = model_iterator.chunk_parser(chunk_dict)
    if generic_chunk:
        return custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
    return None


def get_iterator_class_from_endpoint_type(
    endpoint_type: EndpointType,
) -> Optional[type]:
    if endpoint_type == EndpointType.VERTEX_AI:
        return VertexAIIterator
    return None


async def chunk_processor(
    response: httpx.Response,
    request_body: Optional[dict],
@@ -52,156 +32,75 @@ async def chunk_processor(
    start_time: datetime,
    passthrough_success_handler_obj: PassThroughEndpointLogging,
    url_route: str,
) -> AsyncIterable[Union[str, bytes]]:
    request_body = request_body or {}
    iteratorClass = get_iterator_class_from_endpoint_type(endpoint_type)
    aiter_bytes = response.aiter_bytes()
    aiter_lines = response.aiter_lines()
    all_chunks = []
    if iteratorClass is None:
        # Generic endpoint - litellm does not do any tracking / logging for this
        async for chunk in aiter_lines:
            yield chunk
    elif endpoint_type == EndpointType.ANTHROPIC:
        anthropic_iterator = AnthropicIterator(
            sync_stream=False,
            streaming_response=aiter_lines,
            json_mode=False,
):
"""
- Yields chunks from the response
- Collect non-empty chunks for post-processing (logging)
"""
collected_chunks: List[str] = [] # List to store all chunks
try:
async for chunk in response.aiter_lines():
verbose_proxy_logger.debug(f"Processing chunk: {chunk}")
if not chunk:
continue
# Handle SSE format - pass through the raw SSE format
chunk = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
# Store the chunk for post-processing
if chunk.strip(): # Only store non-empty chunks
collected_chunks.append(chunk)
yield f"{chunk}\n"
# After all chunks are processed, handle post-processing
end_time = datetime.now()
await _route_streaming_logging_to_handler(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body or {},
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=collected_chunks,
end_time=end_time,
)
        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
            completion_stream=aiter_bytes,
            model=None,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="anthropic",
        )
        async for chunk in aiter_lines:
            try:
                generic_chunk = anthropic_iterator.convert_str_chunk_to_generic_chunk(
                    chunk
                )
                litellm_chunk = custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
                if litellm_chunk:
                    all_chunks.append(litellm_chunk)
            except Exception as e:
                verbose_proxy_logger.error(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )
            finally:
                yield chunk
    else:
        # known streaming endpoint - litellm will do tracking / logging for this
        model_iterator = iteratorClass(
            sync_stream=False, streaming_response=aiter_bytes
        )
        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
            completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
        )
        buffer = b""
        async for chunk in aiter_bytes:
            buffer += chunk
            try:
                _decoded_chunk = chunk.decode("utf-8")
                _chunk_dict = json.loads(_decoded_chunk)
                litellm_chunk = get_litellm_chunk(
                    model_iterator, custom_stream_wrapper, _chunk_dict
                )
                if litellm_chunk:
                    all_chunks.append(litellm_chunk)
            except json.JSONDecodeError:
                pass
            finally:
                yield chunk  # Yield the original bytes

        # Process any remaining data in the buffer
        if buffer:
            try:
                _chunk_dict = json.loads(buffer.decode("utf-8"))
                if isinstance(_chunk_dict, list):
                    for _chunk in _chunk_dict:
                        litellm_chunk = get_litellm_chunk(
                            model_iterator, custom_stream_wrapper, _chunk
                        )
                        if litellm_chunk:
                            all_chunks.append(litellm_chunk)
                elif isinstance(_chunk_dict, dict):
                    litellm_chunk = get_litellm_chunk(
                        model_iterator, custom_stream_wrapper, _chunk_dict
                    )
                    if litellm_chunk:
                        all_chunks.append(litellm_chunk)
            except json.JSONDecodeError:
                pass

    await _handle_logging_collected_chunks(
        litellm_logging_obj=litellm_logging_obj,
        passthrough_success_handler_obj=passthrough_success_handler_obj,
        url_route=url_route,
        request_body=request_body,
        endpoint_type=endpoint_type,
        start_time=start_time,
        end_time=datetime.now(),
        all_chunks=all_chunks,
    )
    except Exception as e:
        verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
        raise
async def _handle_logging_collected_chunks(
async def _route_streaming_logging_to_handler(
    litellm_logging_obj: LiteLLMLoggingObj,
    passthrough_success_handler_obj: PassThroughEndpointLogging,
    url_route: str,
    request_body: dict,
    endpoint_type: EndpointType,
    start_time: datetime,
    all_chunks: List[Dict],
    all_chunks: List[str],
    end_time: datetime,
):
    """
    Build the complete response and handle the logging
    Route the logging for the collected chunks to the appropriate handler

    This gets triggered once all the chunks are collected

    Supported endpoint types:
    - Anthropic
    - Vertex AI
    """
    try:
        complete_streaming_response: Optional[
            Union[litellm.ModelResponse, litellm.TextCompletionResponse]
        ] = litellm.stream_chunk_builder(chunks=all_chunks)
        if complete_streaming_response is None:
            complete_streaming_response = litellm.ModelResponse()
        end_time = datetime.now()
        verbose_proxy_logger.debug(
            "complete_streaming_response %s", complete_streaming_response
    if endpoint_type == EndpointType.ANTHROPIC:
        await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
            litellm_logging_obj=litellm_logging_obj,
            passthrough_success_handler_obj=passthrough_success_handler_obj,
            url_route=url_route,
            request_body=request_body,
            endpoint_type=endpoint_type,
            start_time=start_time,
            all_chunks=all_chunks,
            end_time=end_time,
        )
        kwargs = {}
        if passthrough_success_handler_obj.is_vertex_route(url_route):
            _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
            complete_streaming_response.model = _model
            litellm_logging_obj.model = _model
            litellm_logging_obj.model_call_details["model"] = _model
        elif endpoint_type == EndpointType.ANTHROPIC:
            model = request_body.get("model", "")
            kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
                litellm_model_response=complete_streaming_response,
                model=model,
                kwargs=litellm_logging_obj.model_call_details,
                start_time=start_time,
                end_time=end_time,
                logging_obj=litellm_logging_obj,
            )
            litellm_logging_obj.model = model
            complete_streaming_response.model = model
            litellm_logging_obj.model_call_details["model"] = model
        # Remove start_time and end_time from kwargs since they'll be passed explicitly
        kwargs.pop("start_time", None)
        kwargs.pop("end_time", None)
        litellm_logging_obj.model_call_details.update(kwargs)
        asyncio.create_task(
            litellm_logging_obj.async_success_handler(
                result=complete_streaming_response,
                start_time=start_time,
                end_time=end_time,
                **kwargs,
            )
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error handling logging collected chunks: {e}")
    elif endpoint_type == EndpointType.VERTEX_AI:
        pass
    elif endpoint_type == EndpointType.GENERIC:
        # No logging is supported for generic streaming endpoints
        pass
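
To show where the reworked generator would sit, a rough sketch of driving `chunk_processor` from a passthrough route; the httpx/FastAPI wiring and the names `logging_obj` and `success_handler` are assumptions not taken from this commit, and the two parameters hidden by the hunk header (`litellm_logging_obj`, `endpoint_type`) are inferred from the function body above.

# Sketch only: hypothetical caller for the reworked chunk_processor.
from datetime import datetime

import httpx
from fastapi.responses import StreamingResponse


async def forward_anthropic_stream(request_body: dict, logging_obj, success_handler):
    client = httpx.AsyncClient()
    upstream_request = client.build_request(
        "POST",
        "https://api.anthropic.com/v1/messages",  # auth headers omitted in this sketch
        json={**request_body, "stream": True},
    )
    upstream_response = await client.send(upstream_request, stream=True)
    # chunk_processor yields the raw SSE lines back to the client while
    # collecting them for the post-stream logging shown in this commit.
    return StreamingResponse(
        chunk_processor(
            response=upstream_response,
            request_body=request_body,
            litellm_logging_obj=logging_obj,
            endpoint_type=EndpointType.ANTHROPIC,
            start_time=datetime.now(),
            passthrough_success_handler_obj=success_handler,
            url_route="/v1/messages",
        ),
        media_type="text/event-stream",
    )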