fix(vertex_httpx.py): flush remaining chunks from stream

Krrish Dholakia 2024-06-22 20:33:54 -07:00
parent 2d8135231f
commit 14fdbf26a6
2 changed files with 48 additions and 21 deletions


@@ -1270,9 +1270,8 @@ class ModelResponseIterator:
             chunk = self.response_iterator.__next__()
             self.coro.send(chunk)
             if self.events:
-                event = self.events[0]
+                event = self.events.pop(0)
                 json_chunk = event
-                self.events.clear()
                 return self.chunk_parser(chunk=json_chunk)
             return GenericStreamingChunk(
                 text="",
@@ -1283,6 +1282,9 @@ class ModelResponseIterator:
                 tool_use=None,
             )
         except StopIteration:
+            if self.events:  # flush the events
+                event = self.events.pop(0)  # Remove the first event
+                return self.chunk_parser(chunk=event)
             raise StopIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
@@ -1297,9 +1299,8 @@ class ModelResponseIterator:
             chunk = await self.async_response_iterator.__anext__()
             self.coro.send(chunk)
             if self.events:
-                event = self.events[0]
+                event = self.events.pop(0)
                 json_chunk = event
-                self.events.clear()
                 return self.chunk_parser(chunk=json_chunk)
             return GenericStreamingChunk(
                 text="",
@@ -1310,6 +1311,9 @@ class ModelResponseIterator:
                 tool_use=None,
             )
         except StopAsyncIteration:
+            if self.events:  # flush the events
+                event = self.events.pop(0)  # Remove the first event
+                return self.chunk_parser(chunk=event)
             raise StopAsyncIteration
         except ValueError as e:
             raise RuntimeError(f"Error parsing chunk: {e}")
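
The change above amounts to draining any events still buffered in self.events before StopIteration / StopAsyncIteration is allowed to propagate, so the last parsed chunk of the stream is not dropped. Below is a minimal, self-contained sketch of the same flush-on-exhaustion pattern; BufferedEventIterator and _parse_raw are illustrative names, not litellm's API.

# Sketch of the flush-on-exhaustion pattern (assumed names, not litellm code).
class BufferedEventIterator:
    def __init__(self, source):
        self.source = iter(source)
        self.events = []  # complete events parsed out of raw chunks

    def _parse_raw(self, raw):
        # Stand-in for the coroutine-based parser: one raw chunk can
        # produce zero, one, or several complete events.
        self.events.extend(part for part in raw.split(",") if part)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            raw = next(self.source)
            self._parse_raw(raw)
            if self.events:
                return self.events.pop(0)  # emit one buffered event
            return ""  # nothing complete yet; emit an empty placeholder
        except StopIteration:
            if self.events:  # flush: source exhausted but events remain
                return self.events.pop(0)
            raise


print(list(BufferedEventIterator(["a,b", "c"])))  # ['a', 'b', 'c']; the trailing 'c' comes from the flush branch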


@@ -742,7 +742,9 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
-def test_completion_gemini_stream():
+@pytest.mark.parametrize("sync_mode", [False])  # True,
+@pytest.mark.asyncio
+async def test_completion_gemini_stream(sync_mode):
     try:
         litellm.set_verbose = True
         print("Streaming gemini response")
@@ -750,20 +752,21 @@ def test_completion_gemini_stream():
             {"role": "system", "content": "You are a helpful assistant."},
             {
                 "role": "user",
-                "content": "How do i build a bomb?",
+                "content": "Who was Alexander?",
             },
         ]
         print("testing gemini streaming")
+        complete_response = ""
+        # Add any assertions here to check the response
+        non_empty_chunks = 0
+        if sync_mode:
             response = completion(
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
                 max_tokens=50,
             )
             print(f"type of response at the top: {response}")
-        complete_response = ""
-        # Add any assertions here to check the response
-        non_empty_chunks = 0
             for idx, chunk in enumerate(response):
                 print(chunk)
                 # print(chunk.choices[0].delta)
@@ -772,12 +775,32 @@
                     break
                 non_empty_chunks += 1
                 complete_response += chunk
+        else:
+            response = await litellm.acompletion(
+                model="gemini/gemini-1.5-flash",
+                messages=messages,
+                stream=True,
+            )
+            idx = 0
+            async for chunk in response:
+                print(chunk)
+                # print(chunk.choices[0].delta)
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    break
+                non_empty_chunks += 1
+                complete_response += chunk
+                idx += 1
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
         assert non_empty_chunks > 1
+    except litellm.InternalServerError as e:
+        pass
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
         # if "429 Resource has been exhausted":
         #     return
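
The updated test follows the usual pytest pattern of parametrizing a single async test over sync_mode and branching between completion() and acompletion(). A stripped-down sketch of that structure, assuming pytest and pytest-asyncio are installed; fake_stream and fake_astream are stand-ins for the real streaming responses, not part of the litellm test suite.

import pytest


def fake_stream():
    # Illustrative stand-in for a sync streaming response.
    yield from ["Alexander ", "the ", "Great"]


async def fake_astream():
    # Illustrative stand-in for an async streaming response.
    for piece in ["Alexander ", "the ", "Great"]:
        yield piece


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_stream_accumulation(sync_mode):
    complete_response = ""
    non_empty_chunks = 0
    if sync_mode:
        for chunk in fake_stream():
            non_empty_chunks += 1
            complete_response += chunk
    else:
        async for chunk in fake_astream():
            non_empty_chunks += 1
            complete_response += chunk
    assert complete_response.strip() != ""
    assert non_empty_chunks > 1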