add support for anthropic streaming usage tracking

Ishaan Jaff 2024-11-20 19:25:05 -08:00
parent c977677c93
commit 9dc67cfebd
3 changed files with 123 additions and 8 deletions

Changed file 1 of 3

@@ -779,3 +779,24 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
+        str_line = chunk
+        if isinstance(chunk, bytes):  # Handle binary data
+            str_line = chunk.decode("utf-8")  # Convert bytes to string
+        index = str_line.find("data:")
+        if index != -1:
+            str_line = str_line[index:]
+
+        if str_line.startswith("data:"):
+            data_json = json.loads(str_line[5:])
+            return self.chunk_parser(chunk=data_json)
+        else:
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
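
A minimal usage sketch, not part of the commit: lines that carry a "data:" payload are stripped of the prefix, JSON-decoded, and routed through the existing chunk_parser, while event markers and keep-alives map to an empty GenericStreamingChunk. The SSE payload below is illustrative, and the iterator is built with a dummy stream because only the string-conversion helper is exercised.

# Sketch only: exercising convert_str_chunk_to_generic_chunk on raw SSE lines.
from litellm.llms.anthropic.chat.handler import ModelResponseIterator

iterator = ModelResponseIterator(
    streaming_response=None,  # not used by this helper; placeholder for the sketch
    sync_stream=False,
    json_mode=False,
)

# A "data:" line is stripped of its prefix, JSON-decoded, and handed to chunk_parser.
data_line = 'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}'
print(iterator.convert_str_chunk_to_generic_chunk(data_line))

# Lines without a "data:" payload (event markers, keep-alives) yield an empty chunk.
print(iterator.convert_str_chunk_to_generic_chunk("event: content_block_delta"))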

Changed file 2 of 3

@@ -4,7 +4,7 @@ import json
 import traceback
 from base64 import b64encode
 from datetime import datetime
-from typing import AsyncIterable, List, Optional
+from typing import AsyncIterable, List, Optional, Union

 import httpx
 from fastapi import (
@@ -310,13 +310,15 @@ def get_endpoint_type(url: str) -> EndpointType:

 async def stream_response(
     response: httpx.Response,
+    request_body: Optional[dict],
     logging_obj: LiteLLMLoggingObj,
     endpoint_type: EndpointType,
     start_time: datetime,
     url: str,
-) -> AsyncIterable[bytes]:
+) -> AsyncIterable[Union[str, bytes]]:
     async for chunk in chunk_processor(
-        response.aiter_bytes(),
+        response=response,
+        request_body=request_body,
         litellm_logging_obj=logging_obj,
         endpoint_type=endpoint_type,
         start_time=start_time,
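
A side note on the widened return annotation (my reading of the change, not text from the commit): httpx's aiter_lines() yields decoded str lines while aiter_bytes() yields bytes, and Starlette's StreamingResponse accepts either, so the generator's element type becomes Union[str, bytes] once the Anthropic path starts re-yielding SSE lines. A minimal sketch:

from typing import AsyncIterable, Union

async def mixed_chunks() -> AsyncIterable[Union[str, bytes]]:
    # bytes: what the generic passthrough path keeps forwarding
    yield b'data: {"type": "message_start"}\n\n'
    # str: a decoded SSE line, as produced by response.aiter_lines()
    yield 'data: {"type": "message_stop"}\n\n'
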
@@ -468,6 +470,7 @@ async def pass_through_request( # noqa: PLR0915
            return StreamingResponse(
                stream_response(
                    response=response,
+                   request_body=_parsed_body,
                    logging_obj=logging_obj,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
@@ -506,6 +509,7 @@ async def pass_through_request( # noqa: PLR0915
            return StreamingResponse(
                stream_response(
                    response=response,
+                   request_body=_parsed_body,
                    logging_obj=logging_obj,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
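
For context, an illustrative payload (not from the commit): the parsed request body of a streaming Anthropic /v1/messages call is what gets threaded through here, because the downstream logging code recovers the model name from it once the stream has finished.

# Hypothetical request body for a streaming Anthropic passthrough call.
_parsed_body = {
    "model": "claude-3-5-sonnet-20241022",  # hypothetical model id
    "max_tokens": 256,
    "stream": True,
    "messages": [{"role": "user", "content": "Hello"}],
}

# Downstream, _handle_logging_collected_chunks reads the model name like this:
model = _parsed_body.get("model", "")  # falls back to "" if the caller omitted it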

Changed file 3 of 3

@@ -4,13 +4,22 @@ from datetime import datetime
 from enum import Enum
 from typing import AsyncIterable, Dict, List, Optional, Union

+import httpx
 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.anthropic.chat.handler import (
+    ModelResponseIterator as AnthropicIterator,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     ModelResponseIterator as VertexAIIterator,
 )
 from litellm.types.utils import GenericStreamingChunk

+from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
+    AnthropicPassthroughLoggingHandler,
+)
 from .success_handler import PassThroughEndpointLogging
 from .types import EndpointType
@@ -36,19 +45,49 @@ def get_iterator_class_from_endpoint_type(

 async def chunk_processor(
-    aiter_bytes: AsyncIterable[bytes],
+    response: httpx.Response,
+    request_body: Optional[dict],
     litellm_logging_obj: LiteLLMLoggingObj,
     endpoint_type: EndpointType,
     start_time: datetime,
     passthrough_success_handler_obj: PassThroughEndpointLogging,
     url_route: str,
-) -> AsyncIterable[bytes]:
+) -> AsyncIterable[Union[str, bytes]]:
+    request_body = request_body or {}
     iteratorClass = get_iterator_class_from_endpoint_type(endpoint_type)
+    aiter_bytes = response.aiter_bytes()
+    aiter_lines = response.aiter_lines()
+    all_chunks = []
     if iteratorClass is None:
         # Generic endpoint - litellm does not do any tracking / logging for this
-        async for chunk in aiter_bytes:
+        async for chunk in aiter_lines:
             yield chunk
+    elif endpoint_type == EndpointType.ANTHROPIC:
+        anthropic_iterator = AnthropicIterator(
+            sync_stream=False,
+            streaming_response=aiter_lines,
+            json_mode=False,
+        )
+        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
+            completion_stream=aiter_bytes,
+            model=None,
+            logging_obj=litellm_logging_obj,
+            custom_llm_provider="anthropic",
+        )
+        async for chunk in aiter_lines:
+            try:
+                generic_chunk = anthropic_iterator.convert_str_chunk_to_generic_chunk(
+                    chunk
+                )
+                litellm_chunk = custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
+                if litellm_chunk:
+                    all_chunks.append(litellm_chunk)
+            except Exception as e:
+                verbose_proxy_logger.error(
+                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
+                )
+            finally:
+                yield chunk
     else:
         # known streaming endpoint - litellm will do tracking / logging for this
         model_iterator = iteratorClass(
@@ -58,7 +97,6 @@ async def chunk_processor(
             completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
         )
         buffer = b""
-        all_chunks = []
         async for chunk in aiter_bytes:
             buffer += chunk
             try:
@@ -95,23 +133,75 @@ async def chunk_processor(
             except json.JSONDecodeError:
                 pass

+    await _handle_logging_collected_chunks(
+        litellm_logging_obj=litellm_logging_obj,
+        passthrough_success_handler_obj=passthrough_success_handler_obj,
+        url_route=url_route,
+        request_body=request_body,
+        endpoint_type=endpoint_type,
+        start_time=start_time,
+        end_time=datetime.now(),
+        all_chunks=all_chunks,
+    )
+
+
+async def _handle_logging_collected_chunks(
+    litellm_logging_obj: LiteLLMLoggingObj,
+    passthrough_success_handler_obj: PassThroughEndpointLogging,
+    url_route: str,
+    request_body: dict,
+    endpoint_type: EndpointType,
+    start_time: datetime,
+    all_chunks: List[Dict],
+    end_time: datetime,
+):
+    """
+    Build the complete response and handle the logging
+
+    This gets triggered once all the chunks are collected
+    """
+    try:
         complete_streaming_response: Optional[
             Union[litellm.ModelResponse, litellm.TextCompletionResponse]
         ] = litellm.stream_chunk_builder(chunks=all_chunks)
         if complete_streaming_response is None:
             complete_streaming_response = litellm.ModelResponse()
         end_time = datetime.now()
+        verbose_proxy_logger.debug(
+            "complete_streaming_response %s", complete_streaming_response
+        )
+        kwargs = {}

         if passthrough_success_handler_obj.is_vertex_route(url_route):
             _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
             complete_streaming_response.model = _model
             litellm_logging_obj.model = _model
             litellm_logging_obj.model_call_details["model"] = _model
+        elif endpoint_type == EndpointType.ANTHROPIC:
+            model = request_body.get("model", "")
+            kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
+                litellm_model_response=complete_streaming_response,
+                model=model,
+                kwargs=litellm_logging_obj.model_call_details,
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=litellm_logging_obj,
+            )
+            litellm_logging_obj.model = model
+            complete_streaming_response.model = model
+            litellm_logging_obj.model_call_details["model"] = model
+
+        # Remove start_time and end_time from kwargs since they'll be passed explicitly
+        kwargs.pop("start_time", None)
+        kwargs.pop("end_time", None)
+        litellm_logging_obj.model_call_details.update(kwargs)

         asyncio.create_task(
             litellm_logging_obj.async_success_handler(
                 result=complete_streaming_response,
                 start_time=start_time,
                 end_time=end_time,
+                **kwargs,
             )
         )
+    except Exception as e:
+        verbose_proxy_logger.error(f"Error handling logging collected chunks: {e}")