fix(bedrock_httpx.py): fix ai21 streaming

Krrish Dholakia 2024-08-01 22:03:24 -07:00
parent 57e3044974
commit 4c2ef8ea64
2 changed files with 20 additions and 17 deletions

bedrock_httpx.py

@@ -42,8 +42,11 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionUsageBlock,
 )
-from litellm.types.utils import Choices, Message
+from litellm.types.utils import Choices
+from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import Message
 from litellm.utils import (
     CustomStreamWrapper,
     ModelResponse,
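
Note: the import hunk swaps Bedrock's camelCase ConverseTokenUsageBlock for the OpenAI-style ChatCompletionUsageBlock and aliases GenericStreamingChunk as GChunk. A minimal sketch of the shapes involved, with field sets inferred from the call sites in this diff rather than copied from litellm's type definitions:

# Sketch only: fields inferred from how these types are used below.
from typing import Optional, TypedDict

class ConverseTokenUsageBlock(TypedDict):
    # Bedrock Converse reports usage with camelCase keys.
    inputTokens: int
    outputTokens: int
    totalTokens: int

class ChatCompletionUsageBlock(TypedDict):
    # OpenAI-style usage keys expected by the rest of litellm.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class GChunk(TypedDict):
    # GenericStreamingChunk, aliased to GChunk in this commit.
    text: str
    tool_use: Optional[dict]  # ChatCompletionToolCallChunk in litellm
    is_finished: bool
    finish_reason: str
    usage: Optional[ChatCompletionUsageBlock]
    index: int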
@@ -2080,13 +2083,13 @@ class AWSEventStreamDecoder:
         self.model = model
         self.parser = EventStreamJSONParser()
 
-    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def converse_chunk_parser(self, chunk_data: dict) -> GChunk:
         try:
             text = ""
             tool_use: Optional[ChatCompletionToolCallChunk] = None
             is_finished = False
             finish_reason = ""
-            usage: Optional[ConverseTokenUsageBlock] = None
+            usage: Optional[ChatCompletionUsageBlock] = None
             index = int(chunk_data.get("contentBlockIndex", 0))
 
             if "start" in chunk_data:
@@ -2123,9 +2126,13 @@
                 finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
                 is_finished = True
             elif "usage" in chunk_data:
-                usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+                usage = ChatCompletionUsageBlock(
+                    prompt_tokens=chunk_data.get("inputTokens", 0),
+                    completion_tokens=chunk_data.get("outputTokens", 0),
+                    total_tokens=chunk_data.get("totalTokens", 0),
+                )
 
-            response = GenericStreamingChunk(
+            response = GChunk(
                 text=text,
                 tool_use=tool_use,
                 is_finished=is_finished,
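
Note: the usage branch now builds the OpenAI-style block inline instead of splatting the raw event into ConverseTokenUsageBlock. The same key mapping in isolation (convert_usage is a hypothetical helper, not part of the commit):

# Hypothetical helper mirroring the camelCase -> snake_case mapping
# the new elif branch performs inline.
def convert_usage(bedrock_usage: dict) -> dict:
    return {
        "prompt_tokens": bedrock_usage.get("inputTokens", 0),
        "completion_tokens": bedrock_usage.get("outputTokens", 0),
        "total_tokens": bedrock_usage.get("totalTokens", 0),
    }

# e.g. convert_usage({"inputTokens": 12, "outputTokens": 34, "totalTokens": 46})
# -> {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}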
@@ -2137,7 +2144,7 @@
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))
 
-    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
         text = ""
         is_finished = False
         finish_reason = ""
@@ -2180,7 +2187,7 @@
         elif chunk_data.get("completionReason", None):
             is_finished = True
             finish_reason = chunk_data["completionReason"]
-        return GenericStreamingChunk(
+        return GChunk(
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
@@ -2189,7 +2196,7 @@
             tool_use=None,
         )
 
-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
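
Note: only the return annotations of iter_bytes/aiter_bytes change; the decoding loop is untouched. For context, a simplified sketch of the pattern these methods wrap (the real code routes event payloads through EventStreamJSONParser and the chunk parsers above; error handling omitted):

import json
from typing import Iterator

from botocore.eventstream import EventStreamBuffer

def decode_event_stream(byte_iterator: Iterator[bytes]) -> Iterator[dict]:
    buffer = EventStreamBuffer()
    for chunk in byte_iterator:
        buffer.add_data(chunk)   # feed raw HTTP bytes into the framing decoder
        for event in buffer:     # yields each fully-framed event message
            yield json.loads(event.payload.decode("utf-8"))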
@@ -2205,7 +2212,7 @@
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GenericStreamingChunk]:
+    ) -> AsyncIterator[GChunk]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -2245,20 +2252,16 @@ class MockResponseIterator:  # for returning ai21 streaming responses
     def __iter__(self):
         return self
 
-    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
         try:
             chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
-            processed_chunk = GenericStreamingChunk(
+            processed_chunk = GChunk(
                 text=chunk_data.choices[0].message.content or "",  # type: ignore
                 tool_use=None,
                 is_finished=True,
                 finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
-                usage=ConverseTokenUsageBlock(
-                    inputTokens=chunk_usage.prompt_tokens,
-                    outputTokens=chunk_usage.completion_tokens,
-                    totalTokens=chunk_usage.total_tokens,
-                ),
+                usage=chunk_usage,  # type: ignore
                 index=0,
             )
             return processed_chunk
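
Note: per the class comment, MockResponseIterator fakes streaming for ai21 responses that arrive complete rather than as an event stream. Since litellm.Usage already carries OpenAI-style prompt_tokens/completion_tokens/total_tokens fields, it can now pass through as-is instead of being repacked into the camelCase block. A stripped-down sketch of the fake-streaming pattern (illustrative, not the commit's exact class):

# Illustrative single-shot iterator: wraps one complete response and
# yields it exactly once as a finished "streaming" chunk.
class SingleChunkIterator:
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        # the real MockResponseIterator converts the response into a
        # GChunk via _chunk_parser, as shown in the hunk above
        return self.model_response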

utils.py

@@ -10174,7 +10174,7 @@ class CustomStreamWrapper:
                 processed_chunk = self.finish_reason_handler()
                 if self.stream_options is None:  # add usage as hidden param
                     usage = calculate_total_usage(chunks=self.chunks)
-                    setattr(processed_chunk, "usage", usage)
+                    processed_chunk._hidden_params["usage"] = usage
                 ## LOGGING
                 threading.Thread(
                     target=self.logging_obj.success_handler,
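
Note: this change makes the code match its own comment. setattr placed a usage attribute directly on the final chunk, even though usage should only appear in the payload when the caller requests it via stream_options; stashing it in _hidden_params keeps the chunk OpenAI-shaped while still exposing the totals. A consumer-side sketch (illustrative, not from the commit; only the _hidden_params["usage"] access is grounded in this diff):

# Read aggregated usage off the final chunk when streaming was started
# without stream_options. `stream` is assumed to be a CustomStreamWrapper.
last_chunk = None
for last_chunk in stream:
    pass  # consume the stream normally
if last_chunk is not None:
    total_usage = last_chunk._hidden_params.get("usage")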