fix(sagemaker.py): fix streaming logic

Krrish Dholakia 2024-08-27 08:10:47 -07:00
parent c7bbfef846
commit bf81b484c6
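
The diff below wires a dedicated AWSEventStreamDecoder (with a new messages-API mode) into the sagemaker_chat streaming path, so streamed responses are decoded from AWS's event-stream framing rather than the generic handler's default parsing, and it tightens the chunk parsing itself. For orientation, a hedged sketch of the call path this touches; the "sagemaker_chat/" prefix is inferred from custom_llm_provider="sagemaker_chat" in the diff, and the endpoint name, region, and chunk access pattern are illustrative placeholders, not part of the commit:

    import litellm

    # Hedged usage sketch, not part of the commit. Endpoint name and region are
    # placeholders; the "sagemaker_chat/" prefix is assumed from the diff below.
    response = litellm.completion(
        model="sagemaker_chat/my-endpoint",   # hypothetical endpoint
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
        aws_region_name="us-west-2",
    )

    # With this fix, the raw event stream is decoded by AWSEventStreamDecoder
    # (is_messages_api=True) before the stream wrapper re-emits OpenAI-style chunks.
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")
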


@@ -27,6 +27,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -35,8 +36,8 @@ from litellm.utils import (
     get_secret,
 )
-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory
 
 _response_stream_shape_cache = None
@@ -285,6 +286,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )
+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -303,6 +308,7 @@
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder, # type: ignore
             )
 
         ## Load Config
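
Taken together, the two hunks above build a provider-specific decoder per request and inject it into the shared OpenAI-compatible handler through the new streaming_decoder argument, instead of letting that handler guess the wire format. A minimal sketch of the injection pattern, with the generic handler reduced to a stub; the real signature of openai_like_chat_completions.completion is assumed here, not shown in this diff:

    from typing import Iterator, Optional, Protocol

    # Minimal sketch of the injection pattern above; this stub stands in for
    # openai_like_chat_completions.completion.
    class StreamingDecoder(Protocol):
        def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[object]: ...

    def generic_completion_stub(
        raw_stream: Iterator[bytes],
        streaming_decoder: Optional[StreamingDecoder] = None,
    ) -> Iterator[object]:
        if streaming_decoder is not None:
            # the provider-supplied decoder owns the byte-level parsing
            yield from streaming_decoder.iter_bytes(raw_stream)
        else:
            raise NotImplementedError("fallback SSE parsing elided in this sketch")

    # caller side, mirroring the diff:
    # stream = generic_completion_stub(
    #     raw_stream, streaming_decoder=AWSEventStreamDecoder(model="", is_messages_api=True)
    # )
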
@@ -937,12 +943,21 @@ def get_response_stream_shape():
 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser
 
         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+        return openai_chunk
 
     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
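
The decoder now carries an is_messages_api flag and a second parser: messages-API payloads already arrive in OpenAI chat.completion.chunk shape, so they are validated straight into StreamingChatCompletionChunk instead of being mapped into GChunk. A hedged sketch exercising both paths; the module path is an assumption based on this commit's ".." relative imports, and the payload values are illustrative:

    # Hedged sketch, not from the commit. The import path is an assumption based
    # on the relative imports above (i.e. the file living under litellm/llms/sagemaker/).
    from litellm.llms.sagemaker.sagemaker import AWSEventStreamDecoder

    messages_decoder = AWSEventStreamDecoder(model="", is_messages_api=True)
    openai_style_chunk = {  # illustrative values
        "id": "chatcmpl-abc123",
        "object": "chat.completion.chunk",
        "created": 1724770000,
        "model": "my-endpoint",
        "choices": [
            {"index": 0, "delta": {"role": "assistant", "content": "Hello"}, "finish_reason": None}
        ],
    }
    # messages-API path: the payload is already OpenAI-shaped, so validate it directly
    chunk = messages_decoder._chunk_parser_messages_api(chunk_data=openai_style_chunk)

    # default path: raw SageMaker token payloads still become GChunk dicts,
    # now constructed with an explicit usage=None (see the next two hunks)
    default_decoder = AWSEventStreamDecoder(model="my-endpoint")
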
@@ -958,6 +973,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )
 
         return GChunk(
@@ -965,9 +981,12 @@
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )
 
-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
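
Because iter_bytes (and aiter_bytes further down) may now yield either chunk flavor, or None, a consumer that wants the text out of the stream has to narrow the type. A hedged sketch of that, assuming GChunk arrives as a plain dict with a "text" key while StreamingChatCompletionChunk exposes .choices[0].delta like OpenAI's pydantic chunk model; "decoder" and "raw_bytes" are stand-ins:

    from typing import Iterable

    # Hedged consumer sketch; "decoder" is an AWSEventStreamDecoder instance and
    # "raw_bytes" is whatever yields the endpoint's event-stream bytes.
    def collect_text(decoder, raw_bytes: Iterable[bytes]) -> str:
        pieces = []
        for chunk in decoder.iter_bytes(iter(raw_bytes)):
            if chunk is None:
                continue  # unparseable trailing data; see the tail handling below
            if isinstance(chunk, dict):  # GChunk path
                pieces.append(chunk["text"])
            else:  # StreamingChatCompletionChunk path
                pieces.append(chunk.choices[0].delta.content or "")
        return "".join(pieces)
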
@@ -988,6 +1007,9 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -999,16 +1021,20 @@
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -1030,6 +1056,9 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -1041,12 +1070,16 @@
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
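
The async iterator mirrors the sync one, including the new tail behavior: leftover bytes that never form valid JSON are logged and surfaced as a single None instead of the stream simply ending. An async counterpart to the consumer sketch above, with the same assumptions about the chunk shapes:

    from typing import AsyncIterable

    # Async counterpart to collect_text above; "decoder" is an AWSEventStreamDecoder
    # instance and "raw_bytes" is an async source of event-stream bytes.
    async def acollect_text(decoder, raw_bytes: AsyncIterable[bytes]) -> str:
        pieces = []
        async for chunk in decoder.aiter_bytes(raw_bytes):
            if chunk is None:
                continue  # trailing buffer was not valid JSON; already logged
            if isinstance(chunk, dict):  # GChunk path
                pieces.append(chunk["text"])
            else:  # StreamingChatCompletionChunk path
                pieces.append(chunk.choices[0].delta.content or "")
        return "".join(pieces)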