Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
Bug fix - String data: stripped from entire content in streamed Gemini responses (#9070)
* _strip_sse_data_from_chunk
* use _strip_sse_data_from_chunk
* use _strip_sse_data_from_chunk
* use _strip_sse_data_from_chunk
* _strip_sse_data_from_chunk
* test_strip_sse_data_from_chunk
* _strip_sse_data_from_chunk
* testing
* _strip_sse_data_from_chunk
This commit is contained in: parent bf4cbd0cee, commit 571ef58045

7 changed files with 213 additions and 8 deletions
```diff
@@ -637,7 +637,10 @@ class CustomStreamWrapper:
             if isinstance(chunk, bytes):
                 chunk = chunk.decode("utf-8")
             if "text_output" in chunk:
-                response = chunk.replace("data: ", "").strip()
+                response = (
+                    CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+                )
+                response = response.strip()
                 parsed_response = json.loads(response)
             else:
                 return {
```
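For context, a minimal sketch (not part of the diff) of the bug this hunk fixes: the old `chunk.replace("data: ", "")` removed the substring everywhere in the payload, so streamed content that itself contained `data: ` was silently mangled. The payload below is illustrative only.

```python
import litellm

# Hypothetical streamed SSE frame whose content happens to contain "data: "
raw = 'data: {"text": "read the data: field from the response"}'

# Old behavior: strips every occurrence, corrupting the model's own text
print(raw.replace("data: ", ""))
# -> '{"text": "read the field from the response"}'

# New behavior: only the leading SSE prefix is removed
print(litellm.CustomStreamWrapper._strip_sse_data_from_chunk(raw))
# -> '{"text": "read the data: field from the response"}'
```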
```diff
@@ -1828,6 +1831,42 @@ class CustomStreamWrapper:
                 extra_kwargs={},
             )

+    @staticmethod
+    def _strip_sse_data_from_chunk(chunk: Optional[str]) -> Optional[str]:
+        """
+        Strips the 'data: ' prefix from Server-Sent Events (SSE) chunks.
+
+        Some providers like sagemaker send it as `data:`, need to handle both
+
+        SSE messages are prefixed with 'data: ' which is part of the protocol,
+        not the actual content from the LLM. This method removes that prefix
+        and returns the actual content.
+
+        Args:
+            chunk: The SSE chunk that may contain the 'data: ' prefix (string or bytes)
+
+        Returns:
+            The chunk with the 'data: ' prefix removed, or the original chunk
+            if no prefix was found. Returns None if input is None.
+
+        See OpenAI Python Ref for this: https://github.com/openai/openai-python/blob/041bf5a8ec54da19aad0169671793c2078bd6173/openai/api_requestor.py#L100
+        """
+        if chunk is None:
+            return None
+
+        if isinstance(chunk, str):
+            # OpenAI sends `data: `
+            if chunk.startswith("data: "):
+                # Strip the prefix and any leading whitespace that might follow it
+                _length_of_sse_data_prefix = len("data: ")
+                return chunk[_length_of_sse_data_prefix:]
+            elif chunk.startswith("data:"):
+                # Sagemaker sends `data:`, no trailing whitespace
+                _length_of_sse_data_prefix = len("data:")
+                return chunk[_length_of_sse_data_prefix:]
+
+        return chunk
+

 def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
     """Assume most recent usage chunk has total usage uptil then."""
```
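A quick behavioral sketch of the new helper (the JSON payloads are illustrative; outputs shown as comments):

```python
import litellm

strip = litellm.CustomStreamWrapper._strip_sse_data_from_chunk

print(strip('data: {"text": "hi"}'))  # OpenAI-style 'data: '   -> '{"text": "hi"}'
print(strip('data:{"token": "hi"}'))  # SageMaker-style 'data:' -> '{"token": "hi"}'
print(strip("no prefix here"))        # no prefix -> returned unchanged
print(strip(None))                    # None in -> None out
```

Note the check order matters: `startswith("data: ")` must come before `startswith("data:")`, since every `data: ` frame also matches the shorter prefix.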
```diff
@@ -84,7 +84,9 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
         finish_reason = None
         logprobs = None

-        chunk_data = chunk_data.replace("data:", "")
+        chunk_data = (
+            litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk_data) or ""
+        )
         chunk_data = chunk_data.strip()
         if len(chunk_data) == 0 or chunk_data == "[DONE]":
             return {
```
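The `or ""` coercion seen at every call site is needed because the helper is `Optional[str]` in and out. A minimal sketch (the `None` chunk is hypothetical):

```python
import litellm

chunk_data = None  # hypothetical: provider yields an empty event
chunk_data = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk_data) or ""
chunk_data = chunk_data.strip()  # safe: .strip() is never called on None
print(len(chunk_data) == 0)      # -> True, handled by the empty/'[DONE]' branch
```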
```diff
@@ -89,7 +89,7 @@ class ModelResponseIterator:
             raise RuntimeError(f"Error receiving chunk from stream: {e}")

         try:
-            chunk = chunk.replace("data:", "")
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
             chunk = chunk.strip()
             if len(chunk) > 0:
                 json_chunk = json.loads(chunk)
```
```diff
@@ -134,7 +134,7 @@ class ModelResponseIterator:
             raise RuntimeError(f"Error receiving chunk from stream: {e}")

         try:
-            chunk = chunk.replace("data:", "")
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
             chunk = chunk.strip()
             if chunk == "[DONE]":
                 raise StopAsyncIteration
```
```diff
@@ -3,6 +3,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union

 import httpx

+import litellm
 from litellm import verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import GenericStreamingChunk as GChunk
```
```diff
@@ -78,7 +79,11 @@ class AWSEventStreamDecoder:
             message = self._parse_message_from_event(event)
             if message:
                 # remove data: prefix and "\n\n" at the end
-                message = message.replace("data:", "").replace("\n\n", "")
+                message = (
+                    litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                    or ""
+                )
+                message = message.replace("\n\n", "")

                 # Accumulate JSON data
                 accumulated_json += message
```
```diff
@@ -127,7 +132,11 @@ class AWSEventStreamDecoder:
             if message:
                 verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
                 # remove data: prefix and "\n\n" at the end
-                message = message.replace("data:", "").replace("\n\n", "")
+                message = (
+                    litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
+                    or ""
+                )
+                message = message.replace("\n\n", "")

                 # Accumulate JSON data
                 accumulated_json += message
```
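A minimal sketch of the decoder's two-step cleanup on a SageMaker-style frame (the payload shape is illustrative only):

```python
import litellm

# Hypothetical raw frame from the event stream: bare "data:" prefix,
# "\n\n" SSE frame terminator at the end
raw = 'data:{"token": {"text": "Hello"}}\n\n'

message = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(raw) or ""
message = message.replace("\n\n", "")  # second step: drop the frame terminator

accumulated_json = ""
accumulated_json += message
print(accumulated_json)  # -> '{"token": {"text": "Hello"}}'
```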
```diff
@@ -1408,7 +1408,8 @@ class ModelResponseIterator:
         return self.chunk_parser(chunk=json_chunk)

     def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
-        message = chunk.replace("data:", "").replace("\n\n", "")
+        chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+        message = chunk.replace("\n\n", "")

         # Accumulate JSON data
         self.accumulated_json += message
```
```diff
@@ -1431,7 +1432,7 @@ class ModelResponseIterator:

     def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
         try:
-            chunk = chunk.replace("data:", "")
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
             if len(chunk) > 0:
                 """
                 Check if initial chunk valid json
```
tests/code_coverage_tests/check_data_replace_usage.py (new file, 133 lines)

```python
import os
import re
import ast
from pathlib import Path


class DataReplaceVisitor(ast.NodeVisitor):
    """AST visitor that finds calls to .replace("data:", ...) in the code."""

    def __init__(self):
        self.issues = []
        self.current_file = None

    def set_file(self, filename):
        self.current_file = filename

    def visit_Call(self, node):
        # Check for method calls like x.replace(...)
        if isinstance(node.func, ast.Attribute) and node.func.attr == "replace":
            # Check if first argument is "data:"
            if (
                len(node.args) >= 2
                and isinstance(node.args[0], ast.Constant)
                and isinstance(node.args[0].value, str)
                and "data:" in node.args[0].value
            ):
                self.issues.append(
                    {
                        "file": self.current_file,
                        "line": node.lineno,
                        "col": node.col_offset,
                        "text": f'Found .replace("data:", ...) at line {node.lineno}',
                    }
                )

        # Continue visiting child nodes
        self.generic_visit(node)


def check_file_with_ast(file_path):
    """Check a Python file for .replace("data:", ...) using AST parsing."""
    with open(file_path, "r", encoding="utf-8") as f:
        try:
            tree = ast.parse(f.read(), filename=file_path)
            visitor = DataReplaceVisitor()
            visitor.set_file(file_path)
            visitor.visit(tree)
            return visitor.issues
        except SyntaxError:
            return [
                {
                    "file": file_path,
                    "line": 0,
                    "col": 0,
                    "text": "Syntax error in file, could not parse",
                }
            ]


def check_file_with_regex(file_path):
    """Check any file for .replace("data:", ...) using regex."""
    issues = []
    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
        for i, line in enumerate(f, 1):
            matches = re.finditer(r'\.replace\(\s*[\'"]data:[\'"]', line)
            for match in matches:
                issues.append(
                    {
                        "file": file_path,
                        "line": i,
                        "col": match.start(),
                        "text": f'Found .replace("data:", ...) at line {i}',
                    }
                )
    return issues


def scan_directory(base_dir):
    """Scan a directory recursively for files containing .replace("data:", ...)."""
    all_issues = []

    for root, _, files in os.walk(base_dir):
        for file in files:
            print("checking file: ", file)
            file_path = os.path.join(root, file)

            # Skip directories we don't want to check
            if any(
                d in file_path for d in [".git", "__pycache__", ".venv", "node_modules"]
            ):
                continue

            # For Python files, use AST for more accurate parsing
            if file.endswith(".py"):
                issues = check_file_with_ast(file_path)
            # For other files that might contain code, use regex
            elif file.endswith((".js", ".ts", ".jsx", ".tsx", ".md", ".ipynb")):
                issues = check_file_with_regex(file_path)
            else:
                continue

            all_issues.extend(issues)

    return all_issues


def main():
    # Start from the project root directory
    base_dir = "./litellm"

    # Local testing
    # base_dir = "../../litellm"

    print(f"Scanning for .replace('data:', ...) usage in {base_dir}")
    issues = scan_directory(base_dir)

    if issues:
        print(f"\n⚠️ Found {len(issues)} instances of .replace('data:', ...):")
        for issue in issues:
            print(f"{issue['file']}:{issue['line']} - {issue['text']}")

        # Fail the test if issues are found
        raise Exception(
            f"Found {len(issues)} instances of .replace('data:', ...) which may be unsafe. Use litellm.CustomStreamWrapper._strip_sse_data_from_chunk instead."
        )
    else:
        print("✅ No instances of .replace('data:', ...) found.")


if __name__ == "__main__":
    main()
```
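A minimal sketch of the AST check in action (the source snippet and filename are illustrative; it assumes `DataReplaceVisitor` from the file above is in scope):

```python
import ast

# A line the guard is designed to catch
source = 'chunk = chunk.replace("data:", "")\n'
tree = ast.parse(source, filename="example.py")

visitor = DataReplaceVisitor()  # from check_data_replace_usage.py above
visitor.set_file("example.py")
visitor.visit(tree)
print(visitor.issues)
# -> [{'file': 'example.py', 'line': 1, 'col': 8,
#      'text': 'Found .replace("data:", ...) at line 1'}]
```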
```diff
@@ -256,3 +256,24 @@ def test_multi_chunk_reasoning_and_content(
     # Verify final state
     assert initialized_custom_stream_wrapper.sent_first_thinking_block is True
     assert initialized_custom_stream_wrapper.sent_last_thinking_block is True
+
+
+def test_strip_sse_data_from_chunk():
+    """Test the static method that strips 'data: ' prefix from SSE chunks"""
+    # Test with string inputs
+    assert CustomStreamWrapper._strip_sse_data_from_chunk("data: content") == "content"
+    assert (
+        CustomStreamWrapper._strip_sse_data_from_chunk("data:  spaced content")
+        == " spaced content"
+    )
+    assert (
+        CustomStreamWrapper._strip_sse_data_from_chunk("regular content")
+        == "regular content"
+    )
+    assert (
+        CustomStreamWrapper._strip_sse_data_from_chunk("regular content with data:")
+        == "regular content with data:"
+    )
+
+    # Test with None input
+    assert CustomStreamWrapper._strip_sse_data_from_chunk(None) is None
```