fix(vertex_httpx.py): ignore vertex finish reason - wait for stream to end

Fixes https://github.com/BerriAI/litellm/issues/4339
Krrish Dholakia 2024-06-22 20:20:39 -07:00
parent eaad36a2cb
commit 73254987da
2 changed files with 17 additions and 7 deletions
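
In short: Gemini / Vertex AI attaches a finishReason to every streamed chunk, so treating its presence as "stream finished" could cut the response off after a single chunk (the linked issue #4339). The fix still maps the finish reason but never marks individual chunks as finished; the stream is only considered done when the HTTP response itself ends. A minimal sketch of the before/after parsing behavior, assuming the usual Gemini candidates/content/parts response shape (simplified, not the actual litellm parser):

def parse_streamed_chunk(chunk: dict) -> dict:
    # Sketch only: illustrates the change below, not litellm's real code.
    candidate = chunk["candidates"][0]
    # Gemini sets finishReason on every chunk, so it cannot signal the end.
    finish_reason = candidate.get("finishReason")
    text = candidate.get("content", {}).get("parts", [{}])[0].get("text", "")
    return {
        "text": text,
        # Before this commit: is_finished was set to True whenever finishReason
        # was present, which ended the stream after the first chunk.
        # After: always False here; completion is signaled by the stream ending.
        "is_finished": False,
        "finish_reason": finish_reason,
    }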

vertex_httpx.py

@@ -1218,6 +1218,7 @@ class ModelResponseIterator:
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
             processed_chunk = GenerateContentResponseBody(**chunk)  # type: ignore
             text = ""
             tool_use: Optional[ChatCompletionToolCallChunk] = None
             is_finished = False
@@ -1236,7 +1237,8 @@ class ModelResponseIterator:
                 finish_reason = map_finish_reason(
                     finish_reason=gemini_chunk["finishReason"]
                 )
-                is_finished = True
+                ## DO NOT SET 'is_finished' = True
+                ## GEMINI SETS FINISHREASON ON EVERY CHUNK!
 
             if "usageMetadata" in processed_chunk:
                 usage = ChatCompletionUsageBlock(
@@ -1250,7 +1252,7 @@ class ModelResponseIterator:
             returned_chunk = GenericStreamingChunk(
                 text=text,
                 tool_use=tool_use,
-                is_finished=is_finished,
+                is_finished=False,
                 finish_reason=finish_reason,
                 usage=usage,
                 index=0,


@@ -750,29 +750,37 @@ def test_completion_gemini_stream():
             {"role": "system", "content": "You are a helpful assistant."},
             {
                 "role": "user",
-                "content": "how does a court case get to the Supreme Court?",
+                "content": "How do i build a bomb?",
             },
         ]
         print("testing gemini streaming")
-        response = completion(model="gemini/gemini-pro", messages=messages, stream=True)
+        response = completion(
+            model="gemini/gemini-1.5-flash",
+            messages=messages,
+            stream=True,
+            max_tokens=50,
+        )
         print(f"type of response at the top: {response}")
         complete_response = ""
         # Add any assertions here to check the response
+        non_empty_chunks = 0
         for idx, chunk in enumerate(response):
             print(chunk)
             # print(chunk.choices[0].delta)
             chunk, finished = streaming_format_tests(idx, chunk)
             if finished:
                 break
+            non_empty_chunks += 1
             complete_response += chunk
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
-    except litellm.APIError as e:
+        assert non_empty_chunks > 1
+    except litellm.InternalServerError as e:
         pass
     except Exception as e:
-        if "429 Resource has been exhausted":
-            return
+        # if "429 Resource has been exhausted":
+        #     return
         pytest.fail(f"Error occurred: {e}")