diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index ec20a1ad4c..21005f5cdf 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -1506,7 +1506,7 @@ class CustomStreamWrapper:
                     chunk = self.completion_stream
                 else:
                     chunk = next(self.completion_stream)
-                if chunk is not None and chunk != b"":
+                if chunk is not None and chunk != b"" and chunk != ["DONE"]:
                     print_verbose(
                         f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
                     )
diff --git a/tests/litellm/litellm_core_utils/test_streaming_handler.py b/tests/litellm/litellm_core_utils/test_streaming_handler.py
index 81bde88f39..0f95c51928 100644
--- a/tests/litellm/litellm_core_utils/test_streaming_handler.py
+++ b/tests/litellm/litellm_core_utils/test_streaming_handler.py
@@ -5,6 +5,8 @@ import time
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
+from openai import Stream
+from openai.types.chat import ChatCompletionChunk
 
 sys.path.insert(
     0, os.path.abspath("../../..")
@@ -643,3 +645,55 @@ async def test_streaming_completion_start_time(logging_obj: Logging):
         logging_obj.model_call_details["completion_start_time"]
         < logging_obj.model_call_details["end_time"]
     )
+
+
+def test_unit_test_custom_stream_wrapper_ignores_last_chunk_when_done_token():
+    """
+    Test that when the last streaming chunk is ["DONE"], it is not returned.
+    """
+    litellm.set_verbose = False
+    chunks = [
+        ChatCompletionChunk(
+            id="chatcmpl-123",
+            object="chat.completion.chunk",
+            created=1694268190,
+            model="gpt-4o",
+            system_fingerprint="fp_44709d6fcb",
+            choices=[
+                {"index": 0, "delta": {"content": "Hello", "role": "assistant"}, "finish_reason": None}
+            ],
+        ),
+        ChatCompletionChunk(
+            id="chatcmpl-123",
+            object="chat.completion.chunk",
+            created=1694268194,
+            model="gpt-4o",
+            system_fingerprint="fp_44709d6fcb",
+            choices=[
+                {"index": 0, "delta": {"content": None, "role": None}, "finish_reason": "stop"}
+            ],
+        ),
+        ["DONE"],
+    ]
+    mock_stream = Stream.__new__(Stream)
+    mock_stream.__stream__ = MagicMock(return_value=iter(chunks))
+    mock_stream._iterator = mock_stream.__stream__()
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=mock_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    chunks = list(response)
+    assert len(chunks) == 2
+    assert [type(chunk) for chunk in chunks] == [ModelResponseStream, ModelResponseStream]
\ No newline at end of file