fix(bedrock_httpx.py): fix ai21 streaming

parent 57e3044974
commit 4c2ef8ea64

2 changed files with 20 additions and 17 deletions
bedrock_httpx.py

@@ -42,8 +42,11 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionUsageBlock,
 )
-from litellm.types.utils import Choices, Message
+from litellm.types.utils import Choices
+from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import Message
 from litellm.utils import (
     CustomStreamWrapper,
     ModelResponse,
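A note on the aliased import: GenericStreamingChunk from litellm.types.utils is a TypedDict, and the rest of the diff switches every annotation and constructor in this file over to the shorter GChunk name. A minimal sketch of the pattern, assuming the litellm package is installed (field values are illustrative):

# The TypedDict is imported under a short alias; at runtime it is a plain dict.
from litellm.types.utils import GenericStreamingChunk as GChunk

chunk: GChunk = GChunk(
    text="partial output",
    tool_use=None,
    is_finished=False,
    finish_reason="",
    usage=None,
    index=0,
)
print(chunk["text"])  # TypedDict instances support normal dict access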
@@ -2080,13 +2083,13 @@ class AWSEventStreamDecoder:
         self.model = model
         self.parser = EventStreamJSONParser()

-    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def converse_chunk_parser(self, chunk_data: dict) -> GChunk:
         try:
             text = ""
             tool_use: Optional[ChatCompletionToolCallChunk] = None
             is_finished = False
             finish_reason = ""
-            usage: Optional[ConverseTokenUsageBlock] = None
+            usage: Optional[ChatCompletionUsageBlock] = None

             index = int(chunk_data.get("contentBlockIndex", 0))
             if "start" in chunk_data:
@@ -2123,9 +2126,13 @@ class AWSEventStreamDecoder:
             finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
             is_finished = True
         elif "usage" in chunk_data:
-            usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+            usage = ChatCompletionUsageBlock(
+                prompt_tokens=chunk_data.get("inputTokens", 0),
+                completion_tokens=chunk_data.get("outputTokens", 0),
+                total_tokens=chunk_data.get("totalTokens", 0),
+            )

-        response = GenericStreamingChunk(
+        response = GChunk(
             text=text,
             tool_use=tool_use,
             is_finished=is_finished,
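The substantive change in this hunk is the translation from Bedrock's converse usage keys (inputTokens, outputTokens, totalTokens) to the OpenAI-style fields that ChatCompletionUsageBlock carries. A standalone sketch of that key mapping, using plain dicts and a hypothetical helper name rather than the litellm types:

# Sketch: translate Bedrock converse usage keys into OpenAI-style usage fields,
# mirroring the ChatCompletionUsageBlock construction above. Illustrative only.
def map_converse_usage(usage_event: dict) -> dict:
    return {
        "prompt_tokens": usage_event.get("inputTokens", 0),
        "completion_tokens": usage_event.get("outputTokens", 0),
        "total_tokens": usage_event.get("totalTokens", 0),
    }

assert map_converse_usage(
    {"inputTokens": 12, "outputTokens": 30, "totalTokens": 42}
) == {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}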
@@ -2137,7 +2144,7 @@ class AWSEventStreamDecoder:
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))

-    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
         text = ""
         is_finished = False
         finish_reason = ""
@@ -2180,7 +2187,7 @@ class AWSEventStreamDecoder:
         elif chunk_data.get("completionReason", None):
             is_finished = True
             finish_reason = chunk_data["completionReason"]
-        return GenericStreamingChunk(
+        return GChunk(
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
@@ -2189,7 +2196,7 @@ class AWSEventStreamDecoder:
             tool_use=None,
         )

-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -2205,7 +2212,7 @@ class AWSEventStreamDecoder:

     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GenericStreamingChunk]:
+    ) -> AsyncIterator[GChunk]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -2245,20 +2252,16 @@ class MockResponseIterator:  # for returning ai21 streaming responses
     def __iter__(self):
         return self

-    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:

         try:
             chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
-            processed_chunk = GenericStreamingChunk(
+            processed_chunk = GChunk(
                 text=chunk_data.choices[0].message.content or "",  # type: ignore
                 tool_use=None,
                 is_finished=True,
                 finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
-                usage=ConverseTokenUsageBlock(
-                    inputTokens=chunk_usage.prompt_tokens,
-                    outputTokens=chunk_usage.completion_tokens,
-                    totalTokens=chunk_usage.total_tokens,
-                ),
+                usage=chunk_usage,  # type: ignore
                 index=0,
             )
             return processed_chunk
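MockResponseIterator exists because ai21 on Bedrock returns one complete response rather than a native event stream, so the full ModelResponse is replayed as a single finished chunk; with this fix the usage object is passed through as-is instead of being re-wrapped in ConverseTokenUsageBlock. A self-contained sketch of that fake-stream pattern, with stand-in types in place of litellm's:

# Sketch of the fake-stream pattern MockResponseIterator implements: wrap one
# complete response as a single-element iterator so downstream streaming code
# can consume providers that only return full responses. Stand-in types only.
from dataclasses import dataclass


@dataclass
class FullResponse:  # stand-in for litellm's ModelResponse
    text: str
    finish_reason: str
    usage: dict


class SingleChunkIterator:
    """Yields the whole response once, as one finished streaming chunk."""

    def __init__(self, response: FullResponse):
        self.response = response
        self.sent = False

    def __iter__(self):
        return self

    def __next__(self) -> dict:
        if self.sent:
            raise StopIteration
        self.sent = True
        return {
            "text": self.response.text,
            "tool_use": None,
            "is_finished": True,  # the single chunk is also the last chunk
            "finish_reason": self.response.finish_reason,
            "usage": self.response.usage,  # passed through untouched, as in the fix
            "index": 0,
        }


for chunk in SingleChunkIterator(
    FullResponse(text="hi", finish_reason="stop", usage={"total_tokens": 3})
):
    print(chunk["is_finished"], chunk["text"])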
utils.py

@@ -10174,7 +10174,7 @@ class CustomStreamWrapper:
                 processed_chunk = self.finish_reason_handler()
                 if self.stream_options is None:  # add usage as hidden param
                     usage = calculate_total_usage(chunks=self.chunks)
-                    setattr(processed_chunk, "usage", usage)
+                    processed_chunk._hidden_params["usage"] = usage
                 ## LOGGING
                 threading.Thread(
                     target=self.logging_obj.success_handler,
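The CustomStreamWrapper change moves usage off the chunk object itself and into _hidden_params, so the final chunk keeps an OpenAI-compatible shape when stream_options is unset, while callers that know about the hidden params can still read the totals. A toy sketch of the pattern, using a stand-in class rather than litellm's:

# Sketch of the hidden-params pattern: keep the public chunk OpenAI-shaped and
# tuck internal bookkeeping (like aggregated usage) into _hidden_params.
class Chunk:  # stand-in for litellm's streaming chunk object
    def __init__(self, text: str):
        self.text = text
        self._hidden_params: dict = {}


chunk = Chunk(text="done")
usage = {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}

# Before the fix, usage was grafted directly onto the object:
#     setattr(chunk, "usage", usage)   # leaks a non-standard attribute
# After the fix, it travels out-of-band:
chunk._hidden_params["usage"] = usage

print(chunk._hidden_params["usage"]["total_tokens"])  # -> 42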