fix(streaming_handler): ignore ["DONE"] tokens from streaming responses

György Orosz 2025-04-15 16:59:58 +02:00
parent aff0d1a18c
commit 89a41a3b48
2 changed files with 55 additions and 1 deletion


@@ -1506,7 +1506,7 @@ class CustomStreamWrapper:
                     chunk = self.completion_stream
                 else:
                     chunk = next(self.completion_stream)
-                if chunk is not None and chunk != b"":
+                if chunk is not None and chunk != b"" and chunk != ["DONE"]:
                     print_verbose(
                         f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
                     )
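
The guard above drops the `["DONE"]` sentinel before it reaches the chunk-creator logic. The standalone sketch below (a hypothetical helper, not the actual `CustomStreamWrapper` code) shows the same filtering rule in isolation:

```python
# Minimal sketch of the filtering rule added above (illustrative only, not litellm code).
def iter_usable_chunks(completion_stream):
    for chunk in completion_stream:
        # Skip empty chunks and the terminal ["DONE"] sentinel instead of
        # handing them on to the chunk-creation logic.
        if chunk is None or chunk == b"" or chunk == ["DONE"]:
            continue
        yield chunk


# The sentinel at the end of the stream is silently dropped.
assert list(iter_usable_chunks(["hello", " world", ["DONE"]])) == ["hello", " world"]
```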


@@ -5,6 +5,8 @@ import time
 from unittest.mock import MagicMock, Mock, patch
 import pytest
+from openai import Stream
+from openai.types.chat import ChatCompletionChunk
 sys.path.insert(
     0, os.path.abspath("../../..")
@@ -643,3 +645,55 @@ async def test_streaming_completion_start_time(logging_obj: Logging):
         logging_obj.model_call_details["completion_start_time"]
         < logging_obj.model_call_details["end_time"]
     )
+
+
+def test_unit_test_custom_stream_wrapper_ignores_last_chunk_when_done_token():
+    """
+    Test if the last streaming chunk is ["DONE"], it is not returned
+    """
+    litellm.set_verbose = False
+    chunks = [
+        ChatCompletionChunk(
+            id="chatcmpl-123",
+            object="chat.completion.chunk",
+            created=1694268190,
+            model="gpt-4o",
+            system_fingerprint="fp_44709d6fcb",
+            choices=[
+                {"index": 0, "delta": {"content": "Hello", "role": "assistant"}, "finish_reason": None}
+            ],
+        ),
+        ChatCompletionChunk(
+            id="chatcmpl-123",
+            object="chat.completion.chunk",
+            created=1694268194,
+            model="gpt-4o",
+            system_fingerprint="fp_44709d6fcb",
+            choices=[
+                {"index": 0, "delta": {"content": None, "role": None}, "finish_reason": "stop"}
+            ],
+        ),
+        ["DONE"],
+    ]
+
+    mock_stream = Stream.__new__(Stream)
+    mock_stream.__stream__ = MagicMock(return_value=iter(chunks))
+    mock_stream._iterator = mock_stream.__stream__()
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=mock_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    chunks = list(response)
+    assert len(chunks) == 2
+    assert [type(chunk) for chunk in chunks] == [ModelResponseStream, ModelResponseStream]
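
The test fakes an `openai.Stream` without any network call by bypassing `__init__` and wiring an in-memory iterator. A reusable version of that pattern might look like the sketch below (the helper name is hypothetical and not part of litellm or the openai package):

```python
from unittest.mock import MagicMock

from openai import Stream


def make_fake_openai_stream(chunks):
    # Bypass Stream.__init__ so no HTTP response object is required.
    fake = Stream.__new__(Stream)
    # Iterating the Stream pulls from _iterator, so point it at the in-memory chunks,
    # mirroring what the test above does directly.
    fake.__stream__ = MagicMock(return_value=iter(chunks))
    fake._iterator = fake.__stream__()
    return fake
```

`litellm.CustomStreamWrapper` can then consume such a fake stream exactly as in the test above.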