Refactor Gemini stream parser logic: remove the duplicated "single line/multi line" code paths and treat every chunk as multiline.

Mathis Beer 2025-03-04 16:31:50 +01:00
parent 40525a5974
commit 6fda0365b8


@@ -1312,7 +1312,6 @@ class VertexLLM(VertexBase):
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
         self.accumulated_json = ""
         self.sent_first_chunk = False
@@ -1389,78 +1388,38 @@ class ModelResponseIterator:
         self.response_iterator = self.streaming_response
         return self

-    def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
-        chunk = chunk.strip()
-        try:
-            json_chunk = json.loads(chunk)
-        except json.JSONDecodeError as e:
-            if (
-                self.sent_first_chunk is False
-            ):  # only check for accumulated json on first chunk, else raise error. Prevent real errors from being masked.
-                self.chunk_type = "accumulated_json"
-                return self.handle_accumulated_json_chunk(chunk=chunk)
-            raise e
-        if self.sent_first_chunk is False:
-            self.sent_first_chunk = True
-        return self.chunk_parser(chunk=json_chunk)
-
-    def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
-        chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
-        message = chunk.replace("\n\n", "")
-
-        # Accumulate JSON data
-        self.accumulated_json += message
-
-        # Try to parse the accumulated JSON
+    def _flush(self) -> Optional[GenericStreamingChunk]:
         try:
             _data = json.loads(self.accumulated_json)
-            self.accumulated_json = ""  # reset after successful parsing
+            self.accumulated_json = ""
             return self.chunk_parser(chunk=_data)
         except json.JSONDecodeError:
-            # If it's not valid JSON yet, continue to the next event
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
+            return None

     def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
-        try:
-            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
-            if len(chunk) > 0:
-                """
-                Check if initial chunk valid json
-                - if partial json -> enter accumulated json logic
-                - if valid - continue
-                """
-                if self.chunk_type == "valid_json":
-                    return self.handle_valid_json_chunk(chunk=chunk)
-                elif self.chunk_type == "accumulated_json":
-                    return self.handle_accumulated_json_chunk(chunk=chunk)
-            return GenericStreamingChunk(
-                text="",
-                is_finished=False,
-                finish_reason="",
-                usage=None,
-                index=0,
-                tool_use=None,
-            )
-        except Exception:
-            raise
+        chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+        self.accumulated_json += chunk
+        chunk = self._flush()
+        if chunk:
+            return chunk
+        # If it's not valid JSON yet, continue to the next event
+        return GenericStreamingChunk(
+            text="",
+            is_finished=False,
+            finish_reason="",
+            usage=None,
+            index=0,
+            tool_use=None,
+        )

     def __next__(self):
         try:
             chunk = self.response_iterator.__next__()
         except StopIteration:
-            if self.chunk_type == "accumulated_json" and self.accumulated_json:
-                return self.handle_accumulated_json_chunk(chunk="")
+            if self.accumulated_json:
+                chunk = self._flush()
+                if chunk:
+                    return chunk
             raise StopIteration
         except ValueError as e:
             raise RuntimeError(f"Error receiving chunk from stream: {e}")
@@ -1481,8 +1440,10 @@ class ModelResponseIterator:
         try:
             chunk = await self.async_response_iterator.__anext__()
         except StopAsyncIteration:
-            if self.chunk_type == "accumulated_json" and self.accumulated_json:
-                return self.handle_accumulated_json_chunk(chunk="")
+            if self.accumulated_json:
+                chunk = self._flush()
+                if chunk:
+                    return chunk
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error receiving chunk from stream: {e}")
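
The net effect of the refactor: the old parser tracked two modes through self.chunk_type and dispatched to either handle_valid_json_chunk or handle_accumulated_json_chunk; the new parser unconditionally appends each SSE-stripped chunk to self.accumulated_json and lets _flush() attempt a parse, emitting an empty placeholder GenericStreamingChunk until the buffer forms valid JSON. A minimal standalone sketch of that accumulate-and-flush pattern, assuming a hypothetical StreamAccumulator class (illustration only, not part of the litellm codebase, and omitting the SSE "data:" stripping):

import json
from typing import Optional


class StreamAccumulator:
    """Buffers raw stream fragments until they form one valid JSON object."""

    def __init__(self) -> None:
        self.accumulated_json = ""

    def feed(self, fragment: str) -> Optional[dict]:
        # Always accumulate, mirroring the refactored _common_chunk_parsing_logic.
        self.accumulated_json += fragment
        try:
            data = json.loads(self.accumulated_json)
            self.accumulated_json = ""  # reset after a successful parse
            return data
        except json.JSONDecodeError:
            return None  # incomplete JSON; keep buffering until the next fragment


acc = StreamAccumulator()
assert acc.feed('{"text": "hel') is None  # partial payload is buffered
assert acc.feed('lo"}') == {"text": "hello"}  # buffer completes and flushes
assert acc.feed('{"text": "next"}') == {"text": "next"}  # whole payloads still parse in one call

End-of-stream handling falls out of the same pattern: on StopIteration/StopAsyncIteration the iterator gives the buffer one final _flush(), and if the remainder still is not valid JSON it is discarded and the stop exception propagates.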