Bug fix - String data: stripped from entire content in streamed Gemini responses (#9070)

* _strip_sse_data_from_chunk

* use _strip_sse_data_from_chunk

* use _strip_sse_data_from_chunk

* use _strip_sse_data_from_chunk

* _strip_sse_data_from_chunk

* test_strip_sse_data_from_chunk

* _strip_sse_data_from_chunk

* testing

* _strip_sse_data_from_chunk
Ishaan Jaff 2025-03-07 21:06:39 -08:00 committed by GitHub
parent bf4cbd0cee
commit 571ef58045
7 changed files with 213 additions and 8 deletions

View file

@@ -637,7 +637,10 @@ class CustomStreamWrapper:
             if isinstance(chunk, bytes):
                 chunk = chunk.decode("utf-8")
             if "text_output" in chunk:
-                response = chunk.replace("data: ", "").strip()
+                response = (
+                    CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+                )
+                response = response.strip()
                 parsed_response = json.loads(response)
             else:
                 return {
@@ -1828,6 +1831,42 @@ class CustomStreamWrapper:
                 extra_kwargs={},
             )
 
+    @staticmethod
+    def _strip_sse_data_from_chunk(chunk: Optional[str]) -> Optional[str]:
+        """
+        Strips the 'data: ' prefix from Server-Sent Events (SSE) chunks.
+
+        Some providers, like Sagemaker, send it as `data:` (no trailing
+        space), so both variants need to be handled.
+
+        SSE messages are prefixed with 'data: ', which is part of the
+        protocol, not the actual content from the LLM. This method removes
+        that prefix and returns the actual content.
+
+        Args:
+            chunk: The SSE chunk that may contain the 'data: ' prefix
+
+        Returns:
+            The chunk with the 'data: ' prefix removed, or the original
+            chunk if no prefix was found. Returns None if the input is None.
+
+        See the OpenAI Python reference: https://github.com/openai/openai-python/blob/041bf5a8ec54da19aad0169671793c2078bd6173/openai/api_requestor.py#L100
+        """
+        if chunk is None:
+            return None
+
+        if isinstance(chunk, str):
+            # OpenAI sends `data: ` (with a trailing space)
+            if chunk.startswith("data: "):
+                # Strip only the prefix; anything after it is preserved as-is
+                _length_of_sse_data_prefix = len("data: ")
+                return chunk[_length_of_sse_data_prefix:]
+            elif chunk.startswith("data:"):
+                # Sagemaker sends `data:` with no trailing whitespace
+                _length_of_sse_data_prefix = len("data:")
+                return chunk[_length_of_sse_data_prefix:]
+
+        return chunk
+
 
 def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
     """Assume most recent usage chunk has total usage uptil then."""

View file

@@ -84,7 +84,9 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
         finish_reason = None
         logprobs = None
 
-        chunk_data = chunk_data.replace("data:", "")
+        chunk_data = (
+            litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk_data) or ""
+        )
         chunk_data = chunk_data.strip()
         if len(chunk_data) == 0 or chunk_data == "[DONE]":
             return {

View file

@@ -89,7 +89,7 @@ class ModelResponseIterator:
             raise RuntimeError(f"Error receiving chunk from stream: {e}")
 
         try:
-            chunk = chunk.replace("data:", "")
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
             chunk = chunk.strip()
             if len(chunk) > 0:
                 json_chunk = json.loads(chunk)
@@ -134,7 +134,7 @@ class ModelResponseIterator:
             raise RuntimeError(f"Error receiving chunk from stream: {e}")
 
         try:
-            chunk = chunk.replace("data:", "")
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
             chunk = chunk.strip()
             if chunk == "[DONE]":
                 raise StopAsyncIteration

View file

@@ -3,6 +3,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
 
 import httpx
+import litellm
 from litellm import verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import GenericStreamingChunk as GChunk
@@ -78,7 +79,11 @@ class AWSEventStreamDecoder:
             message = self._parse_message_from_event(event)
             if message:
                 # remove data: prefix and "\n\n" at the end
-                message = message.replace("data:", "").replace("\n\n", "")
+                message = (
+                    litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                    or ""
+                )
+                message = message.replace("\n\n", "")
 
                 # Accumulate JSON data
                 accumulated_json += message
@@ -127,7 +132,11 @@ class AWSEventStreamDecoder:
             if message:
                 verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
                 # remove data: prefix and "\n\n" at the end
-                message = message.replace("data:", "").replace("\n\n", "")
+                message = (
+                    litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                    or ""
+                )
+                message = message.replace("\n\n", "")
 
                 # Accumulate JSON data
                 accumulated_json += message
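The Sagemaker call sites are why the helper accepts both prefix variants. A short sketch of the two framings it normalizes, and of the `or ""` guard the call sites add (the sample payloads are hypothetical):

```python
import litellm

strip = litellm.CustomStreamWrapper._strip_sse_data_from_chunk

# OpenAI-style SSE framing: "data: " with a trailing space.
assert strip('data: {"token": "hi"}') == '{"token": "hi"}'

# Sagemaker-style framing: "data:" with no trailing space.
assert strip('data:{"token": "hi"}') == '{"token": "hi"}'

# Chunks without a prefix pass through unchanged; None passes through as None,
# which is why the call sites above append `or ""` before .strip()/.replace().
assert strip('{"token": "hi"}') == '{"token": "hi"}'
assert strip(None) is None
```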

View file

@@ -1408,7 +1408,8 @@ class ModelResponseIterator:
             return self.chunk_parser(chunk=json_chunk)
 
     def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
-        message = chunk.replace("data:", "").replace("\n\n", "")
+        chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+        message = chunk.replace("\n\n", "")
 
         # Accumulate JSON data
         self.accumulated_json += message
@@ -1431,7 +1432,7 @@ class ModelResponseIterator:
 
     def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
         try:
-            chunk = chunk.replace("data:", "")
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
             if len(chunk) > 0:
                 """
                 Check if initial chunk valid json
View file

@@ -0,0 +1,133 @@
import os
import re
import ast
from pathlib import Path


class DataReplaceVisitor(ast.NodeVisitor):
    """AST visitor that finds calls to .replace("data:", ...) in the code."""

    def __init__(self):
        self.issues = []
        self.current_file = None

    def set_file(self, filename):
        self.current_file = filename

    def visit_Call(self, node):
        # Check for method calls like x.replace(...)
        if isinstance(node.func, ast.Attribute) and node.func.attr == "replace":
            # Check if the first argument is "data:"
            if (
                len(node.args) >= 2
                and isinstance(node.args[0], ast.Constant)
                and isinstance(node.args[0].value, str)
                and "data:" in node.args[0].value
            ):
                self.issues.append(
                    {
                        "file": self.current_file,
                        "line": node.lineno,
                        "col": node.col_offset,
                        "text": f'Found .replace("data:", ...) at line {node.lineno}',
                    }
                )

        # Continue visiting child nodes
        self.generic_visit(node)


def check_file_with_ast(file_path):
    """Check a Python file for .replace("data:", ...) using AST parsing."""
    with open(file_path, "r", encoding="utf-8") as f:
        try:
            tree = ast.parse(f.read(), filename=file_path)
            visitor = DataReplaceVisitor()
            visitor.set_file(file_path)
            visitor.visit(tree)
            return visitor.issues
        except SyntaxError:
            return [
                {
                    "file": file_path,
                    "line": 0,
                    "col": 0,
                    "text": "Syntax error in file, could not parse",
                }
            ]


def check_file_with_regex(file_path):
    """Check any file for .replace("data:", ...) using regex."""
    issues = []
    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
        for i, line in enumerate(f, 1):
            matches = re.finditer(r'\.replace\(\s*[\'"]data:[\'"]', line)
            for match in matches:
                issues.append(
                    {
                        "file": file_path,
                        "line": i,
                        "col": match.start(),
                        "text": f'Found .replace("data:", ...) at line {i}',
                    }
                )
    return issues


def scan_directory(base_dir):
    """Scan a directory recursively for files containing .replace("data:", ...)."""
    all_issues = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            print("checking file: ", file)
            file_path = os.path.join(root, file)

            # Skip directories we don't want to check
            if any(
                d in file_path for d in [".git", "__pycache__", ".venv", "node_modules"]
            ):
                continue

            # For Python files, use AST for more accurate parsing
            if file.endswith(".py"):
                issues = check_file_with_ast(file_path)
            # For other files that might contain code, use regex
            elif file.endswith((".js", ".ts", ".jsx", ".tsx", ".md", ".ipynb")):
                issues = check_file_with_regex(file_path)
            else:
                continue

            all_issues.extend(issues)

    return all_issues


def main():
    # Start from the project root directory
    base_dir = "./litellm"

    # Local testing
    # base_dir = "../../litellm"

    print(f"Scanning for .replace('data:', ...) usage in {base_dir}")
    issues = scan_directory(base_dir)

    if issues:
        print(f"\n⚠️ Found {len(issues)} instances of .replace('data:', ...):")
        for issue in issues:
            print(f"{issue['file']}:{issue['line']} - {issue['text']}")

        # Fail the test if issues are found
        raise Exception(
            f"Found {len(issues)} instances of .replace('data:', ...) which may be unsafe. Use litellm.CustomStreamWrapper._strip_sse_data_from_chunk instead."
        )
    else:
        print("✅ No instances of .replace('data:', ...) found.")


if __name__ == "__main__":
    main()
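To illustrate what the AST check catches, a small sketch run against an inline source string (assumes `DataReplaceVisitor` from the file above is in scope; the file name is hypothetical):

```python
import ast

source = 'chunk = chunk.replace("data:", "")\n'

tree = ast.parse(source, filename="example.py")
visitor = DataReplaceVisitor()
visitor.set_file("example.py")
visitor.visit(tree)

print(visitor.issues)
# [{'file': 'example.py', 'line': 1, 'col': 8,
#   'text': 'Found .replace("data:", ...) at line 1'}]
```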

View file

@@ -256,3 +256,24 @@ def test_multi_chunk_reasoning_and_content(
     # Verify final state
     assert initialized_custom_stream_wrapper.sent_first_thinking_block is True
     assert initialized_custom_stream_wrapper.sent_last_thinking_block is True
+
+
+def test_strip_sse_data_from_chunk():
+    """Test the static method that strips the 'data: ' prefix from SSE chunks"""
+    # Test with string inputs
+    assert CustomStreamWrapper._strip_sse_data_from_chunk("data: content") == "content"
+    assert (
+        CustomStreamWrapper._strip_sse_data_from_chunk("data:  spaced content")
+        == " spaced content"
+    )
+    assert (
+        CustomStreamWrapper._strip_sse_data_from_chunk("regular content")
+        == "regular content"
+    )
+    assert (
+        CustomStreamWrapper._strip_sse_data_from_chunk("regular content with data:")
+        == "regular content with data:"
+    )
+
+    # Test with None input
+    assert CustomStreamWrapper._strip_sse_data_from_chunk(None) is None
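These tests cover the `data: ` (trailing-space) variant and the None case; a complementary case for the Sagemaker-style `data:` prefix could look like this (a sketch, not part of the commit):

```python
def test_strip_sse_data_from_chunk_no_trailing_space():
    """Sketch: the Sagemaker-style `data:` prefix, with no trailing space."""
    assert CustomStreamWrapper._strip_sse_data_from_chunk("data:content") == "content"
    # A bare "data:" frame with an empty payload yields an empty string
    assert CustomStreamWrapper._strip_sse_data_from_chunk("data:") == ""
```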