fix(utils.py): fix parallel tool calling when streaming

This commit is contained in:
Krrish Dholakia 2023-11-29 10:56:21 -08:00
parent 9024a47dc2
commit b6bc75e27a
3 changed files with 74 additions and 51 deletions

View file

@ -2125,6 +2125,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
id = None
name = None
type = None
tool_calls_list = []
prev_index = 0
prev_id = None
curr_id = None
curr_index = 0
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
@ -2134,6 +2139,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
@ -2142,10 +2152,17 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls_list.append({"id": prev_id, "index": prev_index, "function": {"arguments": combined_arguments, "name": name}, "type": type})
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
combined_arguments = "".join(argument_list)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["tool_calls"] = [{"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type}]
tool_calls_list.append({"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type})
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["tool_calls"] = tool_calls_list
elif "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None:
argument_list = []
delta = chunks[0]["choices"][0]["delta"]