fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini

This commit is contained in:
Krrish Dholakia 2024-07-06 14:02:25 -07:00
parent 2452753e08
commit faa88a1ab1
6 changed files with 111 additions and 16 deletions

View file

@@ -7950,6 +7950,7 @@ class CustomStreamWrapper:
)
self.messages = getattr(logging_obj, "messages", None)
self.sent_stream_usage = False
self.tool_call = False
self.chunks: List = (
[]
) # keep track of the returned chunks - used for calculating the input/output tokens for stream options
@@ -9192,9 +9193,16 @@ class CustomStreamWrapper:
"is_finished": True,
"finish_reason": chunk.choices[0].finish_reason,
"original_chunk": chunk,
"tool_calls": (
chunk.choices[0].delta.tool_calls
if hasattr(chunk.choices[0].delta, "tool_calls")
else None
),
}
completion_obj["content"] = response_obj["text"]
if response_obj["tool_calls"] is not None:
completion_obj["tool_calls"] = response_obj["tool_calls"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if hasattr(chunk, "id"):
model_response.id = chunk.id
@@ -9352,6 +9360,10 @@ class CustomStreamWrapper:
)
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
## CHECK FOR TOOL USE
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
self.tool_call = True
## RETURN ARG
if (
"content" in completion_obj
@@ -9530,6 +9542,12 @@ class CustomStreamWrapper:
)
else:
model_response.choices[0].finish_reason = "stop"
## if tool use
if (
model_response.choices[0].finish_reason == "stop" and self.tool_call
): # don't overwrite for other - potential error finish reasons
model_response.choices[0].finish_reason = "tool_calls"
return model_response
def __next__(self):
@@ -9583,7 +9601,7 @@ class CustomStreamWrapper:
return response
except StopIteration:
if self.sent_last_chunk == True:
if self.sent_last_chunk is True:
if (
self.sent_stream_usage == False
and self.stream_options is not None