From 8e9acd117bcf8fe14800936c9e8c9ec058d2526c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 26 Aug 2024 15:08:08 -0700
Subject: [PATCH] fix(sagemaker.py): support streaming for messages api

Fixes https://github.com/BerriAI/litellm/issues/5372
---
 litellm/__init__.py                           |  2 +-
 litellm/litellm_core_utils/streaming_utils.py |  4 +-
 litellm/llms/databricks.py                    | 39 ++++++++++---
 litellm/llms/{ => sagemaker}/sagemaker.py     | 57 +++++++++++++++----
 litellm/main.py                               |  2 +-
 litellm/tests/test_sagemaker.py               | 22 +++++--
 litellm/types/llms/openai.py                  |  8 +++
 litellm/types/utils.py                        | 40 +++++++++++--
 8 files changed, 142 insertions(+), 32 deletions(-)
 rename litellm/llms/{ => sagemaker}/sagemaker.py (94%)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index bda12e5dd..6e2309a8c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -856,7 +856,7 @@ from .llms.vertex_httpx import (
 from .llms.vertex_ai import VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.vertex_ai_partner import VertexAILlama3Config
-from .llms.sagemaker import SagemakerConfig
+from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
diff --git a/litellm/litellm_core_utils/streaming_utils.py b/litellm/litellm_core_utils/streaming_utils.py
index ca8d58e9f..386842fe2 100644
--- a/litellm/litellm_core_utils/streaming_utils.py
+++ b/litellm/litellm_core_utils/streaming_utils.py
@@ -13,4 +13,6 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
     # this is an optional field in GenericStreamingChunk, it's not required to be present
     _all_fields.pop("provider_specific_fields", None)

-    return all(key in chunk for key in _all_fields)
+    decision = all(key in _all_fields for key in chunk)
+
+    return decision
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index bd529046a..9b321cd33 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -7,7 +7,7 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -22,7 +22,11 @@ from litellm.types.llms.openai import (
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionUsageBlock,
 )
-from litellm.types.utils import GenericStreamingChunk, ProviderField
+from litellm.types.utils import (
+    CustomStreamingDecoder,
+    GenericStreamingChunk,
+    ProviderField,
+)
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -171,15 +175,21 @@ async def make_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     response = await client.post(api_base, headers=headers, data=data, stream=True)

     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.text)

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_lines(), sync_stream=False
-    )
+    if streaming_decoder is not None:
+        completion_stream: Any = streaming_decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.aiter_lines(), sync_stream=False
+        )
     # LOGGING
     logging_obj.post_call(
         input=messages,
@@ -199,6 +209,7 @@ def make_sync_call(
     model: str,
     messages: list,
     logging_obj,
+    streaming_decoder: Optional[CustomStreamingDecoder] = None,
 ):
     if client is None:
         client = HTTPHandler()  # Create a new client if none provided
@@ -208,9 +219,14 @@ def make_sync_call(
     if response.status_code != 200:
         raise DatabricksError(status_code=response.status_code, message=response.read())

-    completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_lines(), sync_stream=True
-    )
+    if streaming_decoder is not None:
+        completion_stream = streaming_decoder.iter_bytes(
+            response.iter_bytes(chunk_size=1024)
+        )
+    else:
+        completion_stream = ModelResponseIterator(
+            streaming_response=response.iter_lines(), sync_stream=True
+        )

     # LOGGING
     logging_obj.post_call(
@@ -283,6 +299,7 @@ class DatabricksChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
     ) -> CustomStreamWrapper:

         data["stream"] = True
@@ -296,6 +313,7 @@ class DatabricksChatCompletion(BaseLLM):
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
+                streaming_decoder=streaming_decoder,
             ),
             model=model,
             custom_llm_provider=custom_llm_provider,
@@ -371,6 +389,9 @@ class DatabricksChatCompletion(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
         custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[
+            CustomStreamingDecoder
+        ] = None,  # if openai-compatible api needs custom stream decoder - e.g. sagemaker
     ):
         custom_endpoint = custom_endpoint or optional_params.pop(
             "custom_endpoint", None
@@ -436,6 +457,7 @@ class DatabricksChatCompletion(BaseLLM):
                     headers=headers,
                     client=client,
                     custom_llm_provider=custom_llm_provider,
+                    streaming_decoder=streaming_decoder,
                 )
             else:
                 return self.acompletion_function(
@@ -473,6 +495,7 @@ class DatabricksChatCompletion(BaseLLM):
                         model=model,
                         messages=messages,
                         logging_obj=logging_obj,
+                        streaming_decoder=streaming_decoder,
                     ),
                     model=model,
                     custom_llm_provider=custom_llm_provider,
diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker/sagemaker.py
similarity index 94%
rename from litellm/llms/sagemaker.py
rename to litellm/llms/sagemaker/sagemaker.py
index 33be2efb8..32f73f7ee 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker/sagemaker.py
@@ -24,8 +24,11 @@ from litellm.llms.custom_httpx.http_handler import (
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
 )
+from litellm.types.utils import CustomStreamingDecoder
 from litellm.types.utils import GenericStreamingChunk as GChunk
+from litellm.types.utils import StreamingChatCompletionChunk
 from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
@@ -34,8 +37,8 @@ from litellm.utils import (
     get_secret,
 )

-from .base_aws_llm import BaseAWSLLM
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ..base_aws_llm import BaseAWSLLM
+from ..prompt_templates.factory import custom_prompt, prompt_factory


 _response_stream_shape_cache = None
@@ -241,6 +244,10 @@ class SagemakerLLM(BaseAWSLLM):
                 aws_region_name=aws_region_name,
             )

+            custom_stream_decoder = AWSEventStreamDecoder(
+                model="", is_messages_api=True
+            )
+
             return openai_like_chat_completions.completion(
                 model=model,
                 messages=messages,
@@ -259,6 +266,7 @@ class SagemakerLLM(BaseAWSLLM):
                 headers=prepared_request.headers,
                 custom_endpoint=True,
                 custom_llm_provider="sagemaker_chat",
+                streaming_decoder=custom_stream_decoder,  # type: ignore
             )

         ## Load Config
@@ -332,7 +340,7 @@ class SagemakerLLM(BaseAWSLLM):
             )
             return response
         else:
-            if stream is not None and stream == True:
+            if stream is not None and stream is True:
                 sync_handler = _get_httpx_client()
                 sync_response = sync_handler.post(
                     url=prepared_request.url,
@@ -847,12 +855,21 @@ def get_response_stream_shape():


 class AWSEventStreamDecoder:
-    def __init__(self, model: str) -> None:
+    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
         from botocore.parsers import EventStreamJSONParser

         self.model = model
         self.parser = EventStreamJSONParser()
         self.content_blocks: List = []
+        self.is_messages_api = is_messages_api
+
+    def _chunk_parser_messages_api(
+        self, chunk_data: dict
+    ) -> StreamingChatCompletionChunk:
+
+        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
+
+        return openai_chunk

     def _chunk_parser(self, chunk_data: dict) -> GChunk:
         verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
@@ -868,6 +885,7 @@ class AWSEventStreamDecoder:
                 index=_index,
                 is_finished=True,
                 finish_reason="stop",
+                usage=None,
             )

         return GChunk(
@@ -875,9 +893,12 @@ class AWSEventStreamDecoder:
             index=_index,
             is_finished=is_finished,
             finish_reason=finish_reason,
+            usage=None,
         )

-    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer

@@ -898,7 +919,10 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -909,16 +933,20 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
-            except json.JSONDecodeError:
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
+            except json.JSONDecodeError as e:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None

     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
-    ) -> AsyncIterator[GChunk]:
+    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
         """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer

@@ -940,7 +968,10 @@ class AWSEventStreamDecoder:
                     # Try to parse the accumulated JSON
                     try:
                         _data = json.loads(accumulated_json)
-                        yield self._chunk_parser(chunk_data=_data)
+                        if self.is_messages_api:
+                            yield self._chunk_parser_messages_api(chunk_data=_data)
+                        else:
+                            yield self._chunk_parser(chunk_data=_data)
                         # Reset accumulated_json after successful parsing
                         accumulated_json = ""
                     except json.JSONDecodeError:
@@ -951,12 +982,16 @@ class AWSEventStreamDecoder:
         if accumulated_json:
             try:
                 _data = json.loads(accumulated_json)
-                yield self._chunk_parser(chunk_data=_data)
+                if self.is_messages_api:
+                    yield self._chunk_parser_messages_api(chunk_data=_data)
+                else:
+                    yield self._chunk_parser(chunk_data=_data)
             except json.JSONDecodeError:
                 # Handle or log any unparseable data at the end
                 verbose_logger.error(
                     f"Warning: Unparseable JSON data remained: {accumulated_json}"
                 )
+                yield None

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
diff --git a/litellm/main.py b/litellm/main.py
index 2cf836890..adc711735 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -119,7 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
-from .llms.sagemaker import SagemakerLLM
+from .llms.sagemaker.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion
diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
index f407191f8..f0b77af4f 100644
--- a/litellm/tests/test_sagemaker.py
+++ b/litellm/tests/test_sagemaker.py
@@ -120,15 +120,24 @@ async def test_completion_sagemaker_messages_api(sync_mode):

 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [False, True])
-async def test_completion_sagemaker_stream(sync_mode):
+@pytest.mark.parametrize(
+    "model",
+    [
+        "sagemaker_chat/huggingface-pytorch-tgi-inference-2024-08-23-15-48-59-245",
+        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+    ],
+)
+async def test_completion_sagemaker_stream(sync_mode, model):
     try:
+        from litellm.tests.test_streaming import streaming_format_tests
+
         litellm.set_verbose = False
         print("testing sagemaker")
         verbose_logger.setLevel(logging.DEBUG)
         full_text = ""
         if sync_mode is True:
             response = litellm.completion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -138,14 +147,15 @@ async def test_completion_sagemaker_stream(sync_mode, model):
                 input_cost_per_second=0.000420,
             )

-            for chunk in response:
+            for idx, chunk in enumerate(response):
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""

             print("SYNC RESPONSE full text", full_text)
         else:
             response = await litellm.acompletion(
-                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                model=model,
                 messages=[
                     {"role": "user", "content": "hi - what is ur name"},
                 ],
@@ -156,10 +166,12 @@ async def test_completion_sagemaker_stream(sync_mode, model):
             )

             print("streaming response")
-
+            idx = 0
             async for chunk in response:
                 print(chunk)
+                streaming_format_tests(idx=idx, chunk=chunk)
                 full_text += chunk.choices[0].delta.content or ""
+                idx += 1

             print("ASYNC RESPONSE full text", full_text)

diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 5d2c416f9..8d7520f25 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -29,6 +29,7 @@ from openai.types.beta.thread_create_params import (
 from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.threads.message_content import MessageContent
 from openai.types.beta.threads.run import Run
+from openai.types.chat import ChatCompletionChunk
 from pydantic import BaseModel, Field
 from typing_extensions import Dict, Required, TypedDict, override

@@ -456,6 +457,13 @@ class ChatCompletionUsageBlock(TypedDict):
     total_tokens: int


+class OpenAIChatCompletionChunk(ChatCompletionChunk):
+    def __init__(self, **kwargs):
+        # Set the 'object' kwarg to 'chat.completion.chunk'
+        kwargs["object"] = "chat.completion.chunk"
+        super().__init__(**kwargs)
+
+
 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
     learning_rate_multiplier: Optional[Union[str, float]] = (
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 14d5cd1b8..3e5465bdc 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -5,11 +5,16 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 from openai._models import BaseModel as OpenAIObject
+from openai.types.completion_usage import CompletionUsage
 from pydantic import ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override

 from ..litellm_core_utils.core_helpers import map_finish_reason
-from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
+from .llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    OpenAIChatCompletionChunk,
+)


 def _generate_id():  # private helper function
@@ -85,7 +90,7 @@ class GenericStreamingChunk(TypedDict, total=False):
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
-    usage: Optional[ChatCompletionUsageBlock]
+    usage: Required[Optional[ChatCompletionUsageBlock]]
     index: int

     # use this dict if you want to return any provider specific fields in the response
@@ -448,9 +453,6 @@ class Choices(OpenAIObject):
             setattr(self, key, value)


-from openai.types.completion_usage import CompletionUsage
-
-
 class Usage(CompletionUsage):
     def __init__(
         self,
@@ -535,6 +537,17 @@ class StreamingChoices(OpenAIObject):
             setattr(self, key, value)


+class StreamingChatCompletionChunk(OpenAIChatCompletionChunk):
+    def __init__(self, **kwargs):
+
+        new_choices = []
+        for choice in kwargs["choices"]:
+            new_choice = StreamingChoices(**choice).model_dump()
+            new_choices.append(new_choice)
+        kwargs["choices"] = new_choices
+        super().__init__(**kwargs)
+
+
 class ModelResponse(OpenAIObject):
     id: str
     """A unique identifier for the completion."""
@@ -1231,3 +1244,20 @@ class StandardLoggingPayload(TypedDict):
     response: Optional[Union[str, list, dict]]
     model_parameters: dict
     hidden_params: StandardLoggingHiddenParams
+
+
+from typing import AsyncIterator, Iterator
+
+
+class CustomStreamingDecoder:
+    async def aiter_bytes(
+        self, iterator: AsyncIterator[bytes]
+    ) -> AsyncIterator[
+        Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]
+    ]:
+        raise NotImplementedError
+
+    def iter_bytes(
+        self, iterator: Iterator[bytes]
+    ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
+        raise NotImplementedError