forked from phoenix/litellm-mirror

fix(bedrock_httpx.py): fix ai21 streaming

parent 57e3044974
commit 4c2ef8ea64

2 changed files with 20 additions and 17 deletions
litellm/llms/bedrock_httpx.py

@@ -42,8 +42,11 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionUsageBlock,
 )
-from litellm.types.utils import Choices, Message
+from litellm.types.utils import Choices
+from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import Message
 from litellm.utils import (
     CustomStreamWrapper,
     ModelResponse,
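The import changes do two things: chunk-level usage is now typed with the OpenAI-style ChatCompletionUsageBlock instead of Bedrock's ConverseTokenUsageBlock, and GenericStreamingChunk is imported under the short alias GChunk used throughout the rest of the diff. For orientation, a minimal sketch of the two shapes involved; these TypedDicts are illustrative stand-ins built only from the field names visible in the hunks below, not the real litellm.types.utils definitions:

from typing import Optional, TypedDict


class ChatCompletionUsageBlock(TypedDict):
    # OpenAI-style token accounting; replaces Bedrock's
    # inputTokens / outputTokens / totalTokens naming.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class GChunk(TypedDict):
    # One provider-agnostic streaming event, as constructed in the hunks below.
    text: str
    tool_use: Optional[dict]
    is_finished: bool
    finish_reason: str
    usage: Optional[ChatCompletionUsageBlock]
    index: int


chunk: GChunk = {
    "text": "Hello",
    "tool_use": None,
    "is_finished": True,
    "finish_reason": "stop",
    "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4},
    "index": 0,
}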
@@ -2080,13 +2083,13 @@ class AWSEventStreamDecoder:
         self.model = model
         self.parser = EventStreamJSONParser()
 
-    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def converse_chunk_parser(self, chunk_data: dict) -> GChunk:
         try:
             text = ""
             tool_use: Optional[ChatCompletionToolCallChunk] = None
             is_finished = False
             finish_reason = ""
-            usage: Optional[ConverseTokenUsageBlock] = None
+            usage: Optional[ChatCompletionUsageBlock] = None
 
             index = int(chunk_data.get("contentBlockIndex", 0))
             if "start" in chunk_data:
@@ -2123,9 +2126,13 @@
                 finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
                 is_finished = True
             elif "usage" in chunk_data:
-                usage = ConverseTokenUsageBlock(**chunk_data["usage"])  # type: ignore
+                usage = ChatCompletionUsageBlock(
+                    prompt_tokens=chunk_data.get("inputTokens", 0),
+                    completion_tokens=chunk_data.get("outputTokens", 0),
+                    total_tokens=chunk_data.get("totalTokens", 0),
+                )
 
-            response = GenericStreamingChunk(
+            response = GChunk(
                 text=text,
                 tool_use=tool_use,
                 is_finished=is_finished,
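Net effect of this hunk: instead of splatting Bedrock's usage event straight into a ConverseTokenUsageBlock, the parser remaps the camelCase counters onto OpenAI-style snake_case fields. A sketch of just that mapping, with a hypothetical helper name:

def map_bedrock_usage(chunk_data: dict) -> dict:
    # Hypothetical helper, for illustration only: remap Bedrock's camelCase
    # usage counters to OpenAI-style snake_case names, defaulting to 0.
    return {
        "prompt_tokens": chunk_data.get("inputTokens", 0),
        "completion_tokens": chunk_data.get("outputTokens", 0),
        "total_tokens": chunk_data.get("totalTokens", 0),
    }


assert map_bedrock_usage(
    {"inputTokens": 12, "outputTokens": 5, "totalTokens": 17}
) == {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}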
@@ -2137,7 +2144,7 @@
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))
 
-    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: dict) -> GChunk:
         text = ""
         is_finished = False
         finish_reason = ""
@@ -2180,7 +2187,7 @@
         elif chunk_data.get("completionReason", None):
             is_finished = True
             finish_reason = chunk_data["completionReason"]
-        return GenericStreamingChunk(
+        return GChunk(
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
@@ -2189,7 +2196,7 @@
             tool_use=None,
         )
 
-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
+    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -2205,7 +2212,7 @@
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GenericStreamingChunk]:
+    ) -> AsyncIterator[GChunk]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -2245,20 +2252,16 @@ class MockResponseIterator:  # for returning ai21 streaming responses
     def __iter__(self):
         return self
 
-    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
 
         try:
             chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
-            processed_chunk = GenericStreamingChunk(
+            processed_chunk = GChunk(
                 text=chunk_data.choices[0].message.content or "",  # type: ignore
                 tool_use=None,
                 is_finished=True,
                 finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
-                usage=ConverseTokenUsageBlock(
-                    inputTokens=chunk_usage.prompt_tokens,
-                    outputTokens=chunk_usage.completion_tokens,
-                    totalTokens=chunk_usage.total_tokens,
-                ),
+                usage=chunk_usage,  # type: ignore
                 index=0,
             )
             return processed_chunk
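For context, MockResponseIterator exists because ai21 models on Bedrock hand back a complete response rather than a true event stream; litellm fakes streaming by emitting the whole answer as a single finished chunk (note is_finished=True above). The iterator pattern in play, sketched with a hypothetical stand-in class:

class OneShotStream:
    # Hypothetical stand-in for the pattern: wrap an already-complete
    # response object and yield it exactly once, marked as finished.
    def __init__(self, response):
        self._response = response
        self._sent = False

    def __iter__(self):
        return self

    def __next__(self):
        if self._sent:
            raise StopIteration
        self._sent = True
        return self._response


for chunk in OneShotStream({"text": "full answer", "is_finished": True}):
    print(chunk["text"])  # prints exactly once: the whole response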
litellm/utils.py

@@ -10174,7 +10174,7 @@ class CustomStreamWrapper:
             processed_chunk = self.finish_reason_handler()
             if self.stream_options is None:  # add usage as hidden param
                 usage = calculate_total_usage(chunks=self.chunks)
-                setattr(processed_chunk, "usage", usage)
+                processed_chunk._hidden_params["usage"] = usage
             ## LOGGING
             threading.Thread(
                 target=self.logging_obj.success_handler,
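The second file's change is behavioral: when the caller never set stream_options, the aggregated usage is no longer setattr'd onto the final chunk as a public usage attribute; it now rides along in the chunk's _hidden_params dict, leaving the chunk's public schema unchanged. A consumer-side sketch under these assumptions (the model id is only an example; substitute one you have access to, with valid AWS credentials):

import litellm

response = litellm.completion(
    model="bedrock/ai21.j2-mid-v1",  # example Bedrock ai21 model id
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)

last = None
for chunk in response:
    last = chunk  # keep the final chunk of the stream

# After this change, aggregate usage (if litellm attached it) lives in
# the hidden params rather than as a public attribute on the chunk.
usage = getattr(last, "_hidden_params", {}).get("usage")
if usage is not None:
    print(usage)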