fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372
Krrish Dholakia 2024-08-26 15:08:08 -07:00
parent 174b1c43e3
commit 8e9acd117b
8 changed files with 142 additions and 32 deletions
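In short: `sagemaker_chat/` models (the SageMaker Messages API) now stream through the OpenAI-compatible completion path by handing it a custom stream decoder (`AWSEventStreamDecoder(model="", is_messages_api=True)`). A minimal sketch of the call this enables, using the endpoint name from the test added in this commit (substitute your own deployed SageMaker endpoint):

import litellm

# Endpoint name copied from the test below - replace with your own deployment.
response = litellm.completion(
    model="sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
    messages=[{"role": "user", "content": "hi - what is ur name"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")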

View file

@@ -856,7 +856,7 @@ from .llms.vertex_httpx import (
 from .llms.vertex_ai import VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.vertex_ai_partner import VertexAILlama3Config
-from .llms.sagemaker import SagemakerConfig
+from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig

View file

@@ -13,4 +13,6 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
     # this is an optional field in GenericStreamingChunk, it's not required to be present
     _all_fields.pop("provider_specific_fields", None)

-    return all(key in chunk for key in _all_fields)
+    decision = all(key in _all_fields for key in chunk)
+    return decision

View file

@@ -7,7 +7,7 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -22,7 +22,11 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, ProviderField
+from litellm.types.utils import (
+    CustomStreamingDecoder,
+    GenericStreamingChunk,
+    ProviderField,
+)
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -171,15 +175,21 @@ async def make_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     response = await client.post(api_base, headers=headers, data=data, stream=True)

     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.text)

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
-    )
+    if streaming_decoder is not None:
+        completion_stream: Any = streaming_decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.aiter_lines(), sync_stream=False
+        )

     # LOGGING
     logging_obj.post_call(
         input=messages,
@@ -199,6 +209,7 @@ def make_sync_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     if client is None:
         client = HTTPHandler()  # Create a new client if none provided
@@ -208,9 +219,14 @@ def make_sync_call(
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.read())

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
-    )
+    if streaming_decoder is not None:
+        completion_stream = streaming_decoder.iter_bytes(
+            response.iter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.iter_lines(), sync_stream=True
+        )

     # LOGGING
     logging_obj.post_call(
@@ -283,6 +299,7 @@ class DatabricksChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
     ) -> CustomStreamWrapper:
         data["stream"] = True
@@ -296,6 +313,7 @@ class DatabricksChatCompletion(BaseLLM):
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
+                streaming_decoder=streaming_decoder,
             ),
             model=model,
             custom_llm_provider=custom_llm_provider,
@@ -371,6 +389,9 @@ class DatabricksChatCompletion(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
         custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[
+            CustomStreamingDecoder
+        ] = None,  # if openai-compatible api needs custom stream decoder - e.g. sagemaker
     ):
         custom_endpoint = custom_endpoint or optional_params.pop(
             "custom_endpoint", None
@@ -436,6 +457,7 @@ class DatabricksChatCompletion(BaseLLM):
                     headers=headers,
                     client=client,
                     custom_llm_provider=custom_llm_provider,
+                    streaming_decoder=streaming_decoder,
                 )
             else:
                 return self.acompletion_function(
@@ -473,6 +495,7 @@ class DatabricksChatCompletion(BaseLLM):
                     model=model,
                     messages=messages,
                     logging_obj=logging_obj,
+                    streaming_decoder=streaming_decoder,
                 ),
                 model=model,
                 custom_llm_provider=custom_llm_provider,

View file

@@ -24,8 +24,11 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
 )
+from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -34,8 +37,8 @@ from litellm.utils import (
     get_secret,
 )

-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory

 _response_stream_shape_cache = None
@@ -241,6 +244,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )

+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -259,6 +266,7 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder,  # type: ignore
             )

         ## Load Config
@@ -332,7 +340,7 @@ class SagemakerLLM(BaseAWSLLM):
             )
             return response
         else:
-            if stream is not None and stream == True:
+            if stream is not None and stream is True:
                 sync_handler = _get_httpx_client()
                 sync_response = sync_handler.post(
                     url=prepared_request.url,
@@ -847,12 +855,21 @@ def get_response_stream_shape():

 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser

         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+        return openai_chunk

     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -868,6 +885,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )

         return GChunk(
@@ -875,9 +893,12 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )

-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -898,7 +919,10 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -909,16 +933,20 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None

     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
@@ -940,7 +968,10 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -951,12 +982,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
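For reference, when `is_messages_api=True` the decoder wraps each parsed JSON payload in an OpenAI-style chunk instead of a `GChunk`. A rough sketch of what `_chunk_parser_messages_api` does, assuming the endpoint already emits chat-completion-chunk shaped JSON (the payload below is illustrative, not a captured response):

# Illustrative payload - field values invented for the example.
raw_chunk = {
    "id": "chatcmpl-123",
    "created": 1724700000,
    "model": "my-sagemaker-endpoint",
    "object": "chat.completion.chunk",
    "choices": [
        {"index": 0, "delta": {"role": "assistant", "content": "Hello"}, "finish_reason": None}
    ],
}

decoder = AWSEventStreamDecoder(model="", is_messages_api=True)
openai_chunk = decoder._chunk_parser_messages_api(chunk_data=raw_chunk)
print(openai_chunk.choices[0].delta.content)  # "Hello"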

View file

@@ -119,7 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
-from .llms.sagemaker import SagemakerLLM
+from .llms.sagemaker.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion

View file

@@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):

 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
-async def test_completion_sagemaker_stream(sync_mode):
+@pytest.mark.parametrize(
+    "model",
+    [
+        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+    ],
+)
+async def test_completion_sagemaker_stream(sync_mode, model):
     try:
+        from litellm.tests.test_streaming import streaming_format_tests
+
         litellm.set_verbose = False
         print("testing sagemaker")
         verbose_logger.setLevel(logging.DEBUG)
         full_text = ""
         if sync_mode is True:
             response = litellm.completion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode):
                 input_cost_per_second=0.000420,
             )

-            for chunk in response:
+            for idx, chunk in enumerate(response):
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""

             print("SYNC RESPONSE full text", full_text)
         else:
             response = await litellm.acompletion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -156,10 +166,12 @@ async def test_completion_sagemaker_stream(sync_mode):
             )

             print("streaming response")
+            idx = 0
             async for chunk in response:
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
+                idx += 1

             print("ASYNC RESPONSE full text", full_text)

View file

@@ -29,6 +29,7 @@ from openai.types.beta.thread_create_params import (
 from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.threads.message_content import MessageContent
 from openai.types.beta.threads.run import Run
+from openai.types.chat import ChatCompletionChunk
 from pydantic import BaseModel, Field
 from typing_extensions import Dict, Required, TypedDict, override
@@ -456,6 +457,13 @@ class ChatCompletionUsageBlock(TypedDict):
     total_tokens: int


+class OpenAIChatCompletionChunk(ChatCompletionChunk):
+    def __init__(self, **kwargs):
+        # Set the 'object' kwarg to 'chat.completion.chunk'
+        kwargs["object"] = "chat.completion.chunk"
+        super().__init__(**kwargs)
+
+
 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
     learning_rate_multiplier: Optional[Union[str, float]] = (
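The subclass presumably exists so that provider payloads which omit or mislabel the `object` field still validate as chat-completion chunks. A small sketch with a hypothetical payload:

# Hypothetical payload: "object" is wrong; the subclass overwrites it before validation.
chunk = OpenAIChatCompletionChunk(
    id="chatcmpl-123",
    created=1724700000,
    model="my-endpoint",
    object="text_completion",
    choices=[{"index": 0, "delta": {"content": "Hi"}, "finish_reason": None}],
)
assert chunk.object == "chat.completion.chunk"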

View file

@@ -5,11 +5,16 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 from openai._models import BaseModel as OpenAIObject
+from openai.types.completion_usage import CompletionUsage
 from pydantic import ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override

 from ..litellm_core_utils.core_helpers import map_finish_reason
-from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from .llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
+)


 def _generate_id():  # private helper function
@@ -85,7 +90,7 @@ class GenericStreamingChunk(TypedDict, total=False):
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
-    usage: Optional[ChatCompletionUsageBlock]
+    usage: Required[Optional[ChatCompletionUsageBlock]]
     index: int

     # use this dict if you want to return any provider specific fields in the response
@@ -448,9 +453,6 @@ class Choices(OpenAIObject):
             setattr(self, key, value)


-from openai.types.completion_usage import CompletionUsage
-
-
 class Usage(CompletionUsage):
     def __init__(
         self,
@@ -535,6 +537,17 @@ class StreamingChoices(OpenAIObject):
             setattr(self, key, value)


+class StreamingChatCompletionChunk(OpenAIChatCompletionChunk):
+    def __init__(self, **kwargs):
+        new_choices = []
+        for choice in kwargs["choices"]:
+            new_choice = StreamingChoices(**choice).model_dump()
+            new_choices.append(new_choice)
+        kwargs["choices"] = new_choices
+        super().__init__(**kwargs)
+
+
 class ModelResponse(OpenAIObject):
     id: str
     """A unique identifier for the completion."""
@@ -1231,3 +1244,20 @@ class StandardLoggingPayload(TypedDict):
     response: Optional[Union[str, list, dict]]
     model_parameters: dict
     hidden_params: StandardLoggingHiddenParams
+
+
+from typing import AsyncIterator, Iterator
+
+
+class CustomStreamingDecoder:
+    async def aiter_bytes(
+        self, iterator: AsyncIterator[bytes]
+    ) -> AsyncIterator[
+        Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]
+    ]:
+        raise NotImplementedError
+
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
+        raise NotImplementedError
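`CustomStreamingDecoder` is only an interface: a provider subclasses it (as `AWSEventStreamDecoder` does above) and passes an instance through the new `streaming_decoder` parameter of the OpenAI-compatible completion path. A minimal, purely hypothetical subclass that decodes newline-delimited JSON chunks, for illustration only:

import json

class NewlineJSONDecoder(CustomStreamingDecoder):
    """Hypothetical decoder: each newline-delimited line of the byte stream is a
    JSON object already shaped like a GenericStreamingChunk."""

    def iter_bytes(self, iterator):
        buffer = b""
        for raw in iterator:
            buffer += raw
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield json.loads(line)

    async def aiter_bytes(self, iterator):
        buffer = b""
        async for raw in iterator:
            buffer += raw
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                if line.strip():
                    yield json.loads(line)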