From 6fda0365b880290d0eb8ec9e87002a7f0ddb5508 Mon Sep 17 00:00:00 2001 From: Mathis Beer Date: Tue, 4 Mar 2025 16:31:50 +0100 Subject: [PATCH] Refactor Gemini stream parser logic: remove duplicated "single line/multi line" logic, just treat it as multiline every time. --- .../vertex_and_google_ai_studio_gemini.py | 89 ++++++------------- 1 file changed, 25 insertions(+), 64 deletions(-) diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 294939a3c5..f4d09a514c 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -1312,7 +1312,6 @@ class VertexLLM(VertexBase): class ModelResponseIterator: def __init__(self, streaming_response, sync_stream: bool): self.streaming_response = streaming_response - self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json" self.accumulated_json = "" self.sent_first_chunk = False @@ -1389,78 +1388,38 @@ class ModelResponseIterator: self.response_iterator = self.streaming_response return self - def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk: - chunk = chunk.strip() - try: - json_chunk = json.loads(chunk) - - except json.JSONDecodeError as e: - if ( - self.sent_first_chunk is False - ): # only check for accumulated json, on first chunk, else raise error. Prevent real errors from being masked. - self.chunk_type = "accumulated_json" - return self.handle_accumulated_json_chunk(chunk=chunk) - raise e - - if self.sent_first_chunk is False: - self.sent_first_chunk = True - - return self.chunk_parser(chunk=json_chunk) - - def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk: - chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" - message = chunk.replace("\n\n", "") - - # Accumulate JSON data - self.accumulated_json += message - - # Try to parse the accumulated JSON + def _flush(self) -> Optional[GenericStreamingChunk]: try: _data = json.loads(self.accumulated_json) - self.accumulated_json = "" # reset after successful parsing + self.accumulated_json = "" return self.chunk_parser(chunk=_data) except json.JSONDecodeError: - # If it's not valid JSON yet, continue to the next event - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) + return None def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk: - try: - chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" - if len(chunk) > 0: - """ - Check if initial chunk valid json - - if partial json -> enter accumulated json logic - - if valid - continue - """ - if self.chunk_type == "valid_json": - return self.handle_valid_json_chunk(chunk=chunk) - elif self.chunk_type == "accumulated_json": - return self.handle_accumulated_json_chunk(chunk=chunk) - - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) - except Exception: - raise + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" + self.accumulated_json += chunk + chunk = self._flush() + if chunk: + return chunk + # If it's not valid JSON yet, continue to the next event + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) def __next__(self): try: chunk = self.response_iterator.__next__() except StopIteration: - if self.chunk_type == "accumulated_json" and self.accumulated_json: - return self.handle_accumulated_json_chunk(chunk="") + if self.accumulated_json: + chunk = self._flush() + if chunk: + return chunk raise StopIteration except ValueError as e: raise RuntimeError(f"Error receiving chunk from stream: {e}") @@ -1481,8 +1440,10 @@ class ModelResponseIterator: try: chunk = await self.async_response_iterator.__anext__() except StopAsyncIteration: - if self.chunk_type == "accumulated_json" and self.accumulated_json: - return self.handle_accumulated_json_chunk(chunk="") + if self.accumulated_json: + chunk = self._flush() + if chunk: + return chunk raise StopAsyncIteration except ValueError as e: raise RuntimeError(f"Error receiving chunk from stream: {e}")