fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372
This commit is contained in:
Krrish Dholakia 2024-08-26 15:08:08 -07:00
parent 174b1c43e3
commit 8e9acd117b
8 changed files with 142 additions and 32 deletions

View file

@ -856,7 +856,7 @@ from .llms.vertex_httpx import (
from .llms.vertex_ai import VertexAITextEmbeddingConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.vertex_ai_partner import VertexAILlama3Config
from .llms.sagemaker import SagemakerConfig
from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig

View file

@ -13,4 +13,6 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
# this is an optional field in GenericStreamingChunk, it's not required to be present
_all_fields.pop("provider_specific_fields", None)
return all(key in chunk for key in _all_fields)
decision = all(key in _all_fields for key in chunk)
return decision

View file

@ -7,7 +7,7 @@ import time
import types
from enum import Enum
from functools import partial
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@ -22,7 +22,11 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.types.utils import (
CustomStreamingDecoder,
GenericStreamingChunk,
ProviderField,
)
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
@ -171,15 +175,21 @@ async def make_call(
model: str,
messages: list,
logging_obj,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
):
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise DatabricksError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
)
if streaming_decoder is not None:
completion_stream: Any = streaming_decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
else:
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
input=messages,
@ -199,6 +209,7 @@ def make_sync_call(
model: str,
messages: list,
logging_obj,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
@ -208,9 +219,14 @@ def make_sync_call(
if response.status_code != 200:
raise DatabricksError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
)
if streaming_decoder is not None:
completion_stream = streaming_decoder.iter_bytes(
response.iter_bytes(chunk_size=1024)
)
else:
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
@ -283,6 +299,7 @@ class DatabricksChatCompletion(BaseLLM):
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
) -> CustomStreamWrapper:
data["stream"] = True
@ -296,6 +313,7 @@ class DatabricksChatCompletion(BaseLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
streaming_decoder=streaming_decoder,
),
model=model,
custom_llm_provider=custom_llm_provider,
@ -371,6 +389,9 @@ class DatabricksChatCompletion(BaseLLM):
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
custom_endpoint: Optional[bool] = None,
streaming_decoder: Optional[
CustomStreamingDecoder
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
):
custom_endpoint = custom_endpoint or optional_params.pop(
"custom_endpoint", None
@ -436,6 +457,7 @@ class DatabricksChatCompletion(BaseLLM):
headers=headers,
client=client,
custom_llm_provider=custom_llm_provider,
streaming_decoder=streaming_decoder,
)
else:
return self.acompletion_function(
@ -473,6 +495,7 @@ class DatabricksChatCompletion(BaseLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
streaming_decoder=streaming_decoder,
),
model=model,
custom_llm_provider=custom_llm_provider,

View file

@ -24,8 +24,11 @@ from litellm.llms.custom_httpx.http_handler import (
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
OpenAIChatCompletionChunk,
)
from litellm.types.utils import CustomStreamingDecoder
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk
from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
@ -34,8 +37,8 @@ from litellm.utils import (
get_secret,
)
from .base_aws_llm import BaseAWSLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
from ..base_aws_llm import BaseAWSLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
_response_stream_shape_cache = None
@ -241,6 +244,10 @@ class SagemakerLLM(BaseAWSLLM):
aws_region_name=aws_region_name,
)
custom_stream_decoder = AWSEventStreamDecoder(
model="", is_messages_api=True
)
return openai_like_chat_completions.completion(
model=model,
messages=messages,
@ -259,6 +266,7 @@ class SagemakerLLM(BaseAWSLLM):
headers=prepared_request.headers,
custom_endpoint=True,
custom_llm_provider="sagemaker_chat",
streaming_decoder=custom_stream_decoder, # type: ignore
)
## Load Config
@ -332,7 +340,7 @@ class SagemakerLLM(BaseAWSLLM):
)
return response
else:
if stream is not None and stream == True:
if stream is not None and stream is True:
sync_handler = _get_httpx_client()
sync_response = sync_handler.post(
url=prepared_request.url,
@ -847,12 +855,21 @@ def get_response_stream_shape():
class AWSEventStreamDecoder:
def __init__(self, model: str) -> None:
def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
from botocore.parsers import EventStreamJSONParser
self.model = model
self.parser = EventStreamJSONParser()
self.content_blocks: List = []
self.is_messages_api = is_messages_api
def _chunk_parser_messages_api(
self, chunk_data: dict
) -> StreamingChatCompletionChunk:
openai_chunk = StreamingChatCompletionChunk(**chunk_data)
return openai_chunk
def _chunk_parser(self, chunk_data: dict) -> GChunk:
verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@ -868,6 +885,7 @@ class AWSEventStreamDecoder:
index=_index,
is_finished=True,
finish_reason="stop",
usage=None,
)
return GChunk(
@ -875,9 +893,12 @@ class AWSEventStreamDecoder:
index=_index,
is_finished=is_finished,
finish_reason=finish_reason,
usage=None,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
def iter_bytes(
self, iterator: Iterator[bytes]
) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
@ -898,7 +919,10 @@ class AWSEventStreamDecoder:
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
@ -909,16 +933,20 @@ class AWSEventStreamDecoder:
if accumulated_json:
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError as e:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
yield None
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[GChunk]:
) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
@ -940,7 +968,10 @@ class AWSEventStreamDecoder:
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
@ -951,12 +982,16 @@ class AWSEventStreamDecoder:
if accumulated_json:
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
yield None
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()

View file

@ -119,7 +119,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
from .llms.sagemaker import SagemakerLLM
from .llms.sagemaker.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
from .llms.triton import TritonChatCompletion

View file

@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [False, True])
async def test_completion_sagemaker_stream(sync_mode):
@pytest.mark.parametrize(
"model",
[
"sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
"sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
],
)
async def test_completion_sagemaker_stream(sync_mode, model):
try:
from litellm.tests.test_streaming import streaming_format_tests
litellm.set_verbose = False
print("testing sagemaker")
verbose_logger.setLevel(logging.DEBUG)
full_text = ""
if sync_mode is True:
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
model=model,
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode):
input_cost_per_second=0.000420,
)
for chunk in response:
for idx, chunk in enumerate(response):
print(chunk)
streaming_format_tests(idx=idx, chunk=chunk)
full_text += chunk.choices[0].delta.content or ""
print("SYNC RESPONSE full text", full_text)
else:
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
model=model,
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
@ -156,10 +166,12 @@ async def test_completion_sagemaker_stream(sync_mode):
)
print("streaming response")
idx = 0
async for chunk in response:
print(chunk)
streaming_format_tests(idx=idx, chunk=chunk)
full_text += chunk.choices[0].delta.content or ""
idx += 1
print("ASYNC RESPONSE full text", full_text)

View file

@ -29,6 +29,7 @@ from openai.types.beta.thread_create_params import (
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.run import Run
from openai.types.chat import ChatCompletionChunk
from pydantic import BaseModel, Field
from typing_extensions import Dict, Required, TypedDict, override
@ -456,6 +457,13 @@ class ChatCompletionUsageBlock(TypedDict):
total_tokens: int
class OpenAIChatCompletionChunk(ChatCompletionChunk):
def __init__(self, **kwargs):
# Set the 'object' kwarg to 'chat.completion.chunk'
kwargs["object"] = "chat.completion.chunk"
super().__init__(**kwargs)
class Hyperparameters(BaseModel):
batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch."
learning_rate_multiplier: Optional[Union[str, float]] = (

View file

@ -5,11 +5,16 @@ from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from openai.types.completion_usage import CompletionUsage
from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Callable, Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
from .llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
OpenAIChatCompletionChunk,
)
def _generate_id(): # private helper function
@ -85,7 +90,7 @@ class GenericStreamingChunk(TypedDict, total=False):
tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ChatCompletionUsageBlock]
usage: Required[Optional[ChatCompletionUsageBlock]]
index: int
# use this dict if you want to return any provider specific fields in the response
@ -448,9 +453,6 @@ class Choices(OpenAIObject):
setattr(self, key, value)
from openai.types.completion_usage import CompletionUsage
class Usage(CompletionUsage):
def __init__(
self,
@ -535,6 +537,17 @@ class StreamingChoices(OpenAIObject):
setattr(self, key, value)
class StreamingChatCompletionChunk(OpenAIChatCompletionChunk):
def __init__(self, **kwargs):
new_choices = []
for choice in kwargs["choices"]:
new_choice = StreamingChoices(**choice).model_dump()
new_choices.append(new_choice)
kwargs["choices"] = new_choices
super().__init__(**kwargs)
class ModelResponse(OpenAIObject):
id: str
"""A unique identifier for the completion."""
@ -1231,3 +1244,20 @@ class StandardLoggingPayload(TypedDict):
response: Optional[Union[str, list, dict]]
model_parameters: dict
hidden_params: StandardLoggingHiddenParams
from typing import AsyncIterator, Iterator
class CustomStreamingDecoder:
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[
Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]
]:
raise NotImplementedError
def iter_bytes(
self, iterator: Iterator[bytes]
) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
raise NotImplementedError