mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini
This commit is contained in:
parent
2452753e08
commit
faa88a1ab1
6 changed files with 111 additions and 16 deletions
|
@ -7950,6 +7950,7 @@ class CustomStreamWrapper:
|
|||
)
|
||||
self.messages = getattr(logging_obj, "messages", None)
|
||||
self.sent_stream_usage = False
|
||||
self.tool_call = False
|
||||
self.chunks: List = (
|
||||
[]
|
||||
) # keep track of the returned chunks - used for calculating the input/output tokens for stream options
|
||||
|
@ -9192,9 +9193,16 @@ class CustomStreamWrapper:
|
|||
"is_finished": True,
|
||||
"finish_reason": chunk.choices[0].finish_reason,
|
||||
"original_chunk": chunk,
|
||||
"tool_calls": (
|
||||
chunk.choices[0].delta.tool_calls
|
||||
if hasattr(chunk.choices[0].delta, "tool_calls")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["tool_calls"] is not None:
|
||||
completion_obj["tool_calls"] = response_obj["tool_calls"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if hasattr(chunk, "id"):
|
||||
model_response.id = chunk.id
|
||||
|
@ -9352,6 +9360,10 @@ class CustomStreamWrapper:
|
|||
)
|
||||
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
|
||||
|
||||
## CHECK FOR TOOL USE
|
||||
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
|
||||
self.tool_call = True
|
||||
|
||||
## RETURN ARG
|
||||
if (
|
||||
"content" in completion_obj
|
||||
|
@ -9530,6 +9542,12 @@ class CustomStreamWrapper:
|
|||
)
|
||||
else:
|
||||
model_response.choices[0].finish_reason = "stop"
|
||||
|
||||
## if tool use
|
||||
if (
|
||||
model_response.choices[0].finish_reason == "stop" and self.tool_call
|
||||
): # don't overwrite for other - potential error finish reasons
|
||||
model_response.choices[0].finish_reason = "tool_calls"
|
||||
return model_response
|
||||
|
||||
def __next__(self):
|
||||
|
@ -9583,7 +9601,7 @@ class CustomStreamWrapper:
|
|||
return response
|
||||
|
||||
except StopIteration:
|
||||
if self.sent_last_chunk == True:
|
||||
if self.sent_last_chunk is True:
|
||||
if (
|
||||
self.sent_stream_usage == False
|
||||
and self.stream_options is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue