forked from phoenix/litellm-mirror
fix(sagemaker.py): fix streaming logic
This commit is contained in:
parent: c7bbfef846
commit: bf81b484c6

1 changed file with 43 additions and 10 deletions
@@ -27,6 +27,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -35,8 +36,8 @@ from litellm.utils import (
     get_secret,
 )
 
-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory
 
 _response_stream_shape_cache = None
 
@@ -285,6 +286,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )
 
+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -303,6 +308,7 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder,  # type: ignore
             )
 
         ## Load Config
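For orientation, here is a minimal runnable sketch of what the two hunks above wire together: a decoder built with is_messages_api=True is handed to the generic OpenAI-like completion layer, which pulls decoded chunks from it. FakeDecoder and fake_completion are hypothetical stand-ins for AWSEventStreamDecoder and openai_like_chat_completions.completion; the real signatures carry many more parameters.

from typing import Iterator, Optional

class FakeDecoder:
    """Hypothetical stand-in for AWSEventStreamDecoder."""

    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
        self.model = model
        self.is_messages_api = is_messages_api

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[dict]:
        # The real decoder parses an AWS event stream; here we just wrap bytes.
        for chunk in iterator:
            yield {"messages_api": self.is_messages_api, "raw": chunk}

def fake_completion(streaming_decoder: FakeDecoder, raw_stream: Iterator[bytes]):
    # Stand-in for openai_like_chat_completions.completion: the caller-supplied
    # decoder decides the chunk format, so SageMaker can reuse the generic path.
    return streaming_decoder.iter_bytes(raw_stream)

decoder = FakeDecoder(model="", is_messages_api=True)
for chunk in fake_completion(decoder, iter([b'{"x": 1}'])):
    print(chunk)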
@@ -937,12 +943,21 @@ def get_response_stream_shape():
 
 
 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser
 
         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk
 
     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
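The new _chunk_parser_messages_api relies on StreamingChatCompletionChunk accepting an OpenAI-format chunk dict via keyword expansion. A rough sketch of that idea using a hypothetical dataclass stand-in (the real class is a typed model in litellm.types.utils; the field names here are assumptions based on the OpenAI chunk format):

from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class ChunkModel:  # hypothetical stand-in for StreamingChatCompletionChunk
    id: str
    object: str
    created: int
    model: str
    choices: List[Any] = field(default_factory=list)

chunk_data = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1700000000,
    "model": "sagemaker-endpoint",
    "choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": None}],
}
parsed = ChunkModel(**chunk_data)  # mirrors StreamingChatCompletionChunk(**chunk_data)
print(parsed.choices[0]["delta"]["content"])  # -> "Hi"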
@@ -958,6 +973,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )
 
         return GChunk(
@@ -965,9 +981,12 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )
 
-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -988,6 +1007,9 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
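The surrounding iter_bytes logic, only partially visible in this hunk, buffers fragments until they form valid JSON, parses, then resets the buffer. A self-contained sketch of that accumulate-and-parse pattern, independent of botocore (function name and inputs are illustrative):

import json
from typing import Iterator

def parse_json_stream(fragments: Iterator[str]) -> Iterator[dict]:
    accumulated_json = ""
    for fragment in fragments:
        accumulated_json += fragment
        try:
            data = json.loads(accumulated_json)
            yield data
            accumulated_json = ""  # reset only after a successful parse
        except json.JSONDecodeError:
            # Incomplete JSON so far; keep accumulating.
            continue

# A message split across three stream chunks still parses exactly once.
pieces = ['{"token": ', '"hel', 'lo"}']
print(list(parse_json_stream(pieces)))  # [{'token': 'hello'}]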
@@ -999,16 +1021,20 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
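Since iter_bytes and aiter_bytes can now yield None when trailing bytes never parse, downstream consumers presumably need to skip None before touching chunk fields. A hypothetical guard, not taken from litellm:

from typing import Iterable, Iterator, Optional

def non_none_chunks(chunks: Iterable[Optional[dict]]) -> Iterator[dict]:
    for chunk in chunks:
        if chunk is None:  # emitted when leftover data never parsed
            continue
        yield chunk

print(list(non_none_chunks([{"text": "a"}, None, {"text": "b"}])))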
@@ -1030,6 +1056,9 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
@@ -1041,12 +1070,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()