From bf81b484c62d57bc11b036ea8749eeeb04da5c9d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 27 Aug 2024 08:10:47 -0700 Subject: [PATCH] fix(sagemaker.py): fix streaming logic --- litellm/llms/sagemaker/sagemaker.py | 53 +++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/litellm/llms/sagemaker/sagemaker.py b/litellm/llms/sagemaker/sagemaker.py index c83b80bcb..5e7776689 100644 --- a/litellm/llms/sagemaker/sagemaker.py +++ b/litellm/llms/sagemaker/sagemaker.py @@ -27,6 +27,7 @@ from litellm.types.llms.openai import ( ChatCompletionUsageBlock, ) from litellm.types.utils import GenericStreamingChunk as GChunk +from litellm.types.utils import StreamingChatCompletionChunk from litellm.utils import ( CustomStreamWrapper, EmbeddingResponse, @@ -35,8 +36,8 @@ from litellm.utils import ( get_secret, ) -from .base_aws_llm import BaseAWSLLM -from .prompt_templates.factory import custom_prompt, prompt_factory +from ..base_aws_llm import BaseAWSLLM +from ..prompt_templates.factory import custom_prompt, prompt_factory _response_stream_shape_cache = None @@ -285,6 +286,10 @@ class SagemakerLLM(BaseAWSLLM): aws_region_name=aws_region_name, ) + custom_stream_decoder = AWSEventStreamDecoder( + model="", is_messages_api=True + ) + return openai_like_chat_completions.completion( model=model, messages=messages, @@ -303,6 +308,7 @@ class SagemakerLLM(BaseAWSLLM): headers=prepared_request.headers, custom_endpoint=True, custom_llm_provider="sagemaker_chat", + streaming_decoder=custom_stream_decoder, # type: ignore ) ## Load Config @@ -937,12 +943,21 @@ def get_response_stream_shape(): class AWSEventStreamDecoder: - def __init__(self, model: str) -> None: + def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None: from botocore.parsers import EventStreamJSONParser self.model = model self.parser = EventStreamJSONParser() self.content_blocks: List = [] + self.is_messages_api = is_messages_api + + def _chunk_parser_messages_api( + self, chunk_data: dict + ) -> StreamingChatCompletionChunk: + + openai_chunk = StreamingChatCompletionChunk(**chunk_data) + + return openai_chunk def _chunk_parser(self, chunk_data: dict) -> GChunk: verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data) @@ -958,6 +973,7 @@ class AWSEventStreamDecoder: index=_index, is_finished=True, finish_reason="stop", + usage=None, ) return GChunk( @@ -965,9 +981,12 @@ class AWSEventStreamDecoder: index=_index, is_finished=is_finished, finish_reason=finish_reason, + usage=None, ) - def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]: + def iter_bytes( + self, iterator: Iterator[bytes] + ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]: """Given an iterator that yields lines, iterate over it & yield every event encountered""" from botocore.eventstream import EventStreamBuffer @@ -988,7 +1007,10 @@ class AWSEventStreamDecoder: # Try to parse the accumulated JSON try: _data = json.loads(accumulated_json) - yield self._chunk_parser(chunk_data=_data) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) # Reset accumulated_json after successful parsing accumulated_json = "" except json.JSONDecodeError: @@ -999,16 +1021,20 @@ class AWSEventStreamDecoder: if accumulated_json: try: _data = json.loads(accumulated_json) - yield self._chunk_parser(chunk_data=_data) - except json.JSONDecodeError: + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) + except json.JSONDecodeError as e: # Handle or log any unparseable data at the end verbose_logger.error( f"Warning: Unparseable JSON data remained: {accumulated_json}" ) + yield None async def aiter_bytes( self, iterator: AsyncIterator[bytes] - ) -> AsyncIterator[GChunk]: + ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]: """Given an async iterator that yields lines, iterate over it & yield every event encountered""" from botocore.eventstream import EventStreamBuffer @@ -1030,7 +1056,10 @@ class AWSEventStreamDecoder: # Try to parse the accumulated JSON try: _data = json.loads(accumulated_json) - yield self._chunk_parser(chunk_data=_data) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) # Reset accumulated_json after successful parsing accumulated_json = "" except json.JSONDecodeError: @@ -1041,12 +1070,16 @@ class AWSEventStreamDecoder: if accumulated_json: try: _data = json.loads(accumulated_json) - yield self._chunk_parser(chunk_data=_data) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) except json.JSONDecodeError: # Handle or log any unparseable data at the end verbose_logger.error( f"Warning: Unparseable JSON data remained: {accumulated_json}" ) + yield None def _parse_message_from_event(self, event) -> Optional[str]: response_dict = event.to_response_dict()