fix(gemini.py): fix async streaming + add native async completions

Krrish Dholakia 2024-02-19 22:41:36 -08:00
parent 45326c93dc
commit 45eb4a5fcc
6 changed files with 224 additions and 17 deletions
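
For context, the async Gemini streaming path patched below is reached through litellm's async entry point. The following is a minimal usage sketch rather than code from this commit; the model name, prompt, and chunk fields are assumptions based on litellm's OpenAI-compatible streaming output.

import asyncio

import litellm


async def main():
    # stream=True makes acompletion return an async iterator of OpenAI-style
    # chunks, which is the CustomStreamWrapper path patched below.
    response = await litellm.acompletion(
        model="gemini/gemini-pro",  # assumed Gemini model name, for illustration only
        messages=[{"role": "user", "content": "Write a haiku about streaming."}],
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)


asyncio.run(main())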

@@ -8417,7 +8417,28 @@ class CustomStreamWrapper:
                        model_response.choices[0].finish_reason = "stop"
                        self.sent_last_chunk = True
            elif self.custom_llm_provider == "gemini":
                completion_obj["content"] = chunk.text
                try:
                    if hasattr(chunk, "parts") == True:
                        try:
                            if len(chunk.parts) > 0:
                                completion_obj["content"] = chunk.parts[0].text
                            if hasattr(chunk.parts[0], "finish_reason"):
                                model_response.choices[0].finish_reason = (
                                    map_finish_reason(chunk.parts[0].finish_reason.name)
                                )
                        except:
                            if chunk.parts[0].finish_reason.name == "SAFETY":
                                raise Exception(
                                    f"The response was blocked by VertexAI. {str(chunk)}"
                                )
                    else:
                        completion_obj["content"] = str(chunk)
                except StopIteration as e:
                    if self.sent_last_chunk:
                        raise e
                    else:
                        model_response.choices[0].finish_reason = "stop"
                        self.sent_last_chunk = True
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                try:
                    if hasattr(chunk, "candidates") == True:
@@ -8727,19 +8748,21 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "ollama_chat"
or self.custom_llm_provider == "vertex_ai"
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
print_verbose(
f"value of async completion stream: {self.completion_stream}"
)
async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}")
print_verbose(
f"value of async chunk: {chunk.parts}; len(chunk.parts): {len(chunk.parts)}"
)
if chunk == "None" or chunk is None:
raise Exception
elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0:
continue
# chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
# __anext__ also calls async_success_handler, which does logging
print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk: Optional[ModelResponse] = self.chunk_creator(
chunk=chunk
)
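
The new len(chunk.parts) == 0 guard is needed because Gemini can emit stream chunks whose parts list is empty (for example a trailing or blocked chunk), and passing such a chunk on to chunk_creator() would fail as soon as parts[0] is indexed. A small illustrative sketch of that skip, with a fake async generator standing in for the real Gemini SDK stream:

import asyncio
from types import SimpleNamespace


async def fake_gemini_stream():
    # Stand-in for the real SDK stream; the last chunk has no parts.
    yield SimpleNamespace(parts=[SimpleNamespace(text="Hello, ")])
    yield SimpleNamespace(parts=[SimpleNamespace(text="world!")])
    yield SimpleNamespace(parts=[])


async def consume():
    async for chunk in fake_gemini_stream():
        if len(chunk.parts) == 0:
            continue  # nothing to build a delta from, so skip instead of erroring
        print(chunk.parts[0].text, end="", flush=True)
    print()


asyncio.run(consume())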