From 14fdbf26a67748d9c09bb56c3d1473f74f22fd62 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 22 Jun 2024 20:33:54 -0700
Subject: [PATCH] fix(vertex_httpx.py): flush remaining chunks from stream

---
 litellm/llms/vertex_httpx.py    | 12 ++++---
 litellm/tests/test_streaming.py | 57 +++++++++++++++++++++++----------
 2 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 38c2d7c470..63bcd9f4f5 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -1270,9 +1270,8 @@ class ModelResponseIterator:
             chunk = self.response_iterator.__next__()
             self.coro.send(chunk)
             if self.events:
-                event = self.events[0]
+                event = self.events.pop(0)
                 json_chunk = event
-                self.events.clear()
                 return self.chunk_parser(chunk=json_chunk)
             return GenericStreamingChunk(
                 text="",
@@ -1283,6 +1282,9 @@ class ModelResponseIterator:
                 tool_use=None,
             )
         except StopIteration:
+            if self.events: # flush the events
+                event = self.events.pop(0) # Remove the first event
+                return self.chunk_parser(chunk=event)
             raise StopIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
@@ -1297,9 +1299,8 @@ class ModelResponseIterator:
             chunk = await self.async_response_iterator.__anext__()
             self.coro.send(chunk)
             if self.events:
-                event = self.events[0]
+                event = self.events.pop(0)
                 json_chunk = event
-                self.events.clear()
                 return self.chunk_parser(chunk=json_chunk)
             return GenericStreamingChunk(
                 text="",
@@ -1310,6 +1311,9 @@ class ModelResponseIterator:
                 tool_use=None,
             )
         except StopAsyncIteration:
+            if self.events: # flush the events
+                event = self.events.pop(0) # Remove the first event
+                return self.chunk_parser(chunk=event)
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 4f7d4c1dea..3042e91b34 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -742,7 +742,9 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
-def test_completion_gemini_stream():
+@pytest.mark.parametrize("sync_mode", [False]) # True,
+@pytest.mark.asyncio
+async def test_completion_gemini_stream(sync_mode):
     try:
         litellm.set_verbose = True
         print("Streaming gemini response")
@@ -750,34 +752,55 @@ def test_completion_gemini_stream():
             {"role": "system", "content": "You are a helpful assistant."},
             {
                 "role": "user",
-                "content": "How do i build a bomb?",
+                "content": "Who was Alexander?",
             },
         ]
         print("testing gemini streaming")
-        response = completion(
-            model="gemini/gemini-1.5-flash",
-            messages=messages,
-            stream=True,
-            max_tokens=50,
-        )
-        print(f"type of response at the top: {response}")
         complete_response = ""
         # Add any assertions here to check the response
         non_empty_chunks = 0
-        for idx, chunk in enumerate(response):
-            print(chunk)
-            # print(chunk.choices[0].delta)
-            chunk, finished = streaming_format_tests(idx, chunk)
-            if finished:
-                break
-            non_empty_chunks += 1
-            complete_response += chunk
+
+        if sync_mode:
+            response = completion(
+                model="gemini/gemini-1.5-flash",
+                messages=messages,
+                stream=True,
+            )
+
+            for idx, chunk in enumerate(response):
+                print(chunk)
+                # print(chunk.choices[0].delta)
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    break
+                non_empty_chunks += 1
+                complete_response += chunk
+        else:
+            response = await litellm.acompletion(
+                model="gemini/gemini-1.5-flash",
+                messages=messages,
+                stream=True,
+            )
+
+            idx = 0
+            async for chunk in response:
+                print(chunk)
+                # print(chunk.choices[0].delta)
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    break
+                non_empty_chunks += 1
+                complete_response += chunk
+                idx += 1
+
         if complete_response.strip() == "":
            raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
         assert non_empty_chunks > 1
     except litellm.InternalServerError as e:
         pass
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         # if "429 Resource has been exhausted":
         #     return
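
For context (reader aid, not part of the patch itself): a minimal standalone sketch of the buffer-and-flush pattern the fix applies to ModelResponseIterator. Parsed events are queued, at most one event is returned per __next__ call, and any events still queued when the underlying stream ends are drained before StopIteration propagates. The class and "parser" below are illustrative stand-ins, not litellm code.

class BufferedEventIterator:
    # Illustrative only: one raw chunk may parse into several events, so
    # events can still be queued after the raw stream is exhausted.
    def __init__(self, raw_chunks):
        self.raw_iter = iter(raw_chunks)
        self.events = []  # parsed but not yet yielded

    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(self.raw_iter)
            self.events.extend(chunk.split(","))  # stand-in "parser"
            if self.events:
                return self.events.pop(0)  # oldest event first
            return ""  # nothing parsed yet
        except StopIteration:
            if self.events:  # flush the events left in the buffer
                return self.events.pop(0)
            raise


print(list(BufferedEventIterator(["a,b", "c,d,e"])))  # ['a', 'b', 'c', 'd', 'e']
# Without the flush branch in the StopIteration handler, 'c', 'd' and 'e'
# would be dropped when the raw stream ends, which is the truncation this
# patch fixes for Gemini streaming responses.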