forked from phoenix/litellm-mirror
fix(sagemaker.py): fix streaming logic
This commit is contained in:
parent c7bbfef846
commit bf81b484c6
1 changed file with 43 additions and 10 deletions
@@ -27,6 +27,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -35,8 +36,8 @@ from litellm.utils import (
     get_secret,
 )
 
-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory
 
 _response_stream_shape_cache = None
 
@@ -285,6 +286,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )
 
+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -303,6 +308,7 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder,  # type: ignore
             )
 
         ## Load Config
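Note: together, the two hunks above wire the `sagemaker_chat` path so that a purpose-built decoder (`is_messages_api=True`) is injected into the generic OpenAI-like handler through the `streaming_decoder` argument. A minimal sketch of that injection pattern, using hypothetical `Decoder`/`completion` stand-ins rather than litellm's actual classes:

# Hypothetical stand-ins, not litellm APIs: they model how a generic
# streaming handler accepts a provider-specific decoder via an argument.
from typing import Iterator


class Decoder:
    def __init__(self, is_messages_api: bool = False) -> None:
        self.is_messages_api = is_messages_api

    def decode(self, raw: bytes) -> str:
        # The real decoder parses an AWS event stream; this just tags the mode.
        mode = "messages" if self.is_messages_api else "generic"
        return f"[{mode}] {raw.decode()}"


def completion(stream: Iterator[bytes], streaming_decoder: Decoder) -> Iterator[str]:
    # The handler stays provider-agnostic; decoding behavior is injected.
    for raw in stream:
        yield streaming_decoder.decode(raw)


for chunk in completion(iter([b"hello", b"world"]), Decoder(is_messages_api=True)):
    print(chunk)  # "[messages] hello", then "[messages] world"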
@@ -937,12 +943,21 @@ def get_response_stream_shape():
 
 
 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser
 
         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
 
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk
+
     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
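Note: the new `is_messages_api` flag selects between two parsers: OpenAI-style `StreamingChatCompletionChunk` objects for the messages API, and litellm's generic chunks otherwise. A minimal sketch of the same routing, with plain dicts standing in for the typed chunk classes (the `token.text` shape below is illustrative, not a confirmed SageMaker payload):

# Toy decoder, not litellm code: `parse` mirrors the is_messages_api branch.
import json
from typing import Optional


class MiniDecoder:
    def __init__(self, is_messages_api: Optional[bool] = None) -> None:
        self.is_messages_api = is_messages_api

    def parse(self, chunk_data: dict) -> dict:
        if self.is_messages_api:
            # Messages-API path: payload is already an OpenAI-style chunk.
            return chunk_data
        # Generic path: normalize into a minimal streaming shape.
        text = chunk_data.get("token", {}).get("text", "")
        return {"text": text, "is_finished": False}


payload = json.loads('{"token": {"text": "hi"}}')
print(MiniDecoder().parse(payload))                      # {'text': 'hi', 'is_finished': False}
print(MiniDecoder(is_messages_api=True).parse(payload))  # passthrough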
@@ -958,6 +973,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )
 
         return GChunk(
@@ -965,9 +981,12 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )
 
-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
@@ -988,7 +1007,10 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -999,16 +1021,20 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
 
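Note: the return types of `iter_bytes`/`aiter_bytes` are widened to `Optional[...]` because the decoder now yields `None` when trailing bytes never parse as JSON, instead of silently dropping them. Callers therefore have to skip `None`; a minimal consumer sketch, with a hypothetical `decode_stream` standing in for `iter_bytes`:

# `decode_stream` is an illustrative stand-in, not a litellm function; the
# point is only that consumers must tolerate a None sentinel in the stream.
from typing import Iterator, Optional


def decode_stream() -> Iterator[Optional[dict]]:
    yield {"text": "hello"}
    yield None  # trailing bytes that never became valid JSON
    yield {"text": "world"}


for chunk in decode_stream():
    if chunk is None:
        continue  # skip the sentinel and keep consuming
    print(chunk["text"])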
@@ -1030,7 +1056,10 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -1041,12 +1070,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()