fix(sagemaker.py): fix streaming logic

Krrish Dholakia 2024-08-27 08:10:47 -07:00
parent c7bbfef846
commit bf81b484c6


@@ -27,6 +27,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -35,8 +36,8 @@ from litellm.utils import (
     get_secret,
 )
-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory

 _response_stream_shape_cache = None
@@ -285,6 +286,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )
+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -303,6 +308,7 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder,  # type: ignore
             )

         ## Load Config
@@ -937,12 +943,21 @@ def get_response_stream_shape():
 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser

         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+        return openai_chunk

     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
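
The constructor hunk above gives the decoder two output modes: the existing _chunk_parser, which builds litellm's provider-agnostic GChunk, and the new _chunk_parser_messages_api, which validates the decoded JSON straight into an OpenAI-style StreamingChatCompletionChunk. A minimal sketch of the new path, assuming an already-decoded event payload (the payload below is invented for illustration):

# Sketch only; the payload is a made-up OpenAI-format chunk.
# Omitting is_messages_api (the default) keeps the existing GChunk behavior,
# which parses SageMaker's native streaming payloads instead.
decoder = AWSEventStreamDecoder(model="", is_messages_api=True)

chunk_data = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1724770000,
    "model": "my-sagemaker-endpoint",
    "choices": [
        {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}
    ],
}

chunk = decoder._chunk_parser_messages_api(chunk_data=chunk_data)
print(type(chunk).__name__)  # StreamingChatCompletionChunk
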
@@ -958,6 +973,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )

         return GChunk(
@@ -965,9 +981,12 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )

-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -988,6 +1007,9 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -999,16 +1021,20 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None

     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
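
Because the trailing-data branch now yields None, the return types of iter_bytes and aiter_bytes widen to Optional[...], so whatever consumes these iterators has to tolerate a None sentinel. A minimal consumer-side sketch, assuming the decoder class defined in this file (the consume helper itself is illustrative, not part of the change):

from typing import Iterator

def consume(decoder: "AWSEventStreamDecoder", byte_iter: Iterator[bytes]) -> None:
    # Skip the None sentinel yielded when leftover bytes are not valid JSON;
    # the decoder has already logged the unparseable data.
    for chunk in decoder.iter_bytes(byte_iter):
        if chunk is None:
            continue
        print(chunk)  # GChunk dict or StreamingChatCompletionChunk
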
@@ -1030,6 +1056,9 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -1041,12 +1070,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
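
End to end, this code path is exercised through the sagemaker_chat provider with streaming enabled. A rough usage sketch; the endpoint name and region are placeholders and valid AWS credentials are assumed:

import litellm

# Placeholder endpoint and region; requires AWS credentials in the environment.
response = litellm.completion(
    model="sagemaker_chat/my-messages-api-endpoint",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    aws_region_name="us-west-2",
)

for chunk in response:
    # Each chunk follows the OpenAI streaming format.
    print(chunk.choices[0].delta.content or "", end="")
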