From b6bc75e27a9abde10cc7f30ee2a63bf2d51784a7 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 29 Nov 2023 10:56:21 -0800
Subject: [PATCH] fix(utils.py): fix parallel tool calling when streaming

---
 litellm/main.py                        | 21 +++++-
 litellm/tests/test_function_calling.py | 89 +++++++++++++-------------
 litellm/utils.py                       | 15 +++--
 3 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 6a665e4a90..99a3ccf3f9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2125,6 +2125,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
         id = None
         name = None
         type = None
+        tool_calls_list = []
+        prev_index = 0
+        prev_id = None
+        curr_id = None
+        curr_index = 0
         for chunk in chunks:
             choices = chunk["choices"]
             for choice in choices:
@@ -2134,6 +2139,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
                 if tool_calls and tool_calls[0].function is not None:
                     if tool_calls[0].id:
                         id = tool_calls[0].id
+                        curr_id = id
+                        if prev_id is None:
+                            prev_id = curr_id
+                    if tool_calls[0].index:
+                        curr_index = tool_calls[0].index
                     if tool_calls[0].function.arguments:
                         # Now, tool_calls is expected to be a dictionary
                         arguments = tool_calls[0].function.arguments
@@ -2142,10 +2152,17 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
                         name = tool_calls[0].function.name
                     if tool_calls[0].type:
                         type = tool_calls[0].type
+            if curr_index != prev_index: # new tool call
+                combined_arguments = "".join(argument_list)
+                tool_calls_list.append({"id": prev_id, "index": prev_index, "function": {"arguments": combined_arguments, "name": name}, "type": type})
+                argument_list = [] # reset
+                prev_index = curr_index
+                prev_id = curr_id
 
         combined_arguments = "".join(argument_list)
-        response["choices"][0]["message"]["content"] = None
-        response["choices"][0]["message"]["tool_calls"] = [{"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type}]
+        tool_calls_list.append({"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type})
+        response["choices"][0]["message"]["content"] = None
+        response["choices"][0]["message"]["tool_calls"] = tool_calls_list
     elif "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None:
         argument_list = []
         delta = chunks[0]["choices"][0]["delta"]
diff --git a/litellm/tests/test_function_calling.py b/litellm/tests/test_function_calling.py
index 67474a3a6c..a7f0225d4f 100644
--- a/litellm/tests/test_function_calling.py
+++ b/litellm/tests/test_function_calling.py
@@ -13,7 +13,7 @@ import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 import pytest
-litellm.num_retries = 3
+litellm.num_retries = 0
 litellm.cache = None
 # litellm.set_verbose=True
 import json
@@ -97,6 +97,7 @@ def test_parallel_function_call():
                         "content": function_response,
                     }
                 )  # extend conversation with function response
+            print(f"messages: {messages}")
             second_response = litellm.completion(
                 model="gpt-3.5-turbo-1106",
                 messages=messages,
@@ -108,7 +109,7 @@ def test_parallel_function_call():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_parallel_function_call()
+test_parallel_function_call()
 
 
 
@@ -143,51 +144,53 @@ def test_parallel_function_call_stream():
             tools=tools,
             stream=True,
             tool_choice="auto",  # auto is default, but we'll be explicit
+            complete_response = True
         )
         print("Response\n", response)
-        for chunk in response:
-            print(chunk)
-        # response_message = response.choices[0].message
-        # tool_calls = response_message.tool_calls
+        # for chunk in response:
+        #     print(chunk)
+        response_message = response.choices[0].message
+        tool_calls = response_message.tool_calls
 
-        # print("length of tool calls", len(tool_calls))
-        # print("Expecting there to be 3 tool calls")
-        # assert len(tool_calls) > 1 # this has to call the function for SF, Tokyo and parise
+        print("length of tool calls", len(tool_calls))
+        print("Expecting there to be 3 tool calls")
+        assert len(tool_calls) > 1 # this has to call the function for SF, Tokyo and parise
 
-        # # Step 2: check if the model wanted to call a function
-        # if tool_calls:
-        #     # Step 3: call the function
-        #     # Note: the JSON response may not always be valid; be sure to handle errors
-        #     available_functions = {
-        #         "get_current_weather": get_current_weather,
-        #     }  # only one function in this example, but you can have multiple
-        #     messages.append(response_message)  # extend conversation with assistant's reply
-        #     print("Response message\n", response_message)
-        #     # Step 4: send the info for each function call and function response to the model
-        #     for tool_call in tool_calls:
-        #         function_name = tool_call.function.name
-        #         function_to_call = available_functions[function_name]
-        #         function_args = json.loads(tool_call.function.arguments)
-        #         function_response = function_to_call(
-        #             location=function_args.get("location"),
-        #             unit=function_args.get("unit"),
-        #         )
-        #         messages.append(
-        #             {
-        #                 "tool_call_id": tool_call.id,
-        #                 "role": "tool",
-        #                 "name": function_name,
-        #                 "content": function_response,
-        #             }
-        #         )  # extend conversation with function response
-        #     second_response = litellm.completion(
-        #         model="gpt-3.5-turbo-1106",
-        #         messages=messages,
-        #         temperature=0.2,
-        #         seed=22
-        #     )  # get a new response from the model where it can see the function response
-        #     print("second response\n", second_response)
-        #     return second_response
+        # Step 2: check if the model wanted to call a function
+        if tool_calls:
+            # Step 3: call the function
+            # Note: the JSON response may not always be valid; be sure to handle errors
+            available_functions = {
+                "get_current_weather": get_current_weather,
+            }  # only one function in this example, but you can have multiple
+            messages.append(response_message)  # extend conversation with assistant's reply
+            print("Response message\n", response_message)
+            # Step 4: send the info for each function call and function response to the model
+            for tool_call in tool_calls:
+                function_name = tool_call.function.name
+                function_to_call = available_functions[function_name]
+                function_args = json.loads(tool_call.function.arguments)
+                function_response = function_to_call(
+                    location=function_args.get("location"),
+                    unit=function_args.get("unit"),
+                )
+                messages.append(
+                    {
+                        "tool_call_id": tool_call.id,
+                        "role": "tool",
+                        "name": function_name,
+                        "content": function_response,
+                    }
+                )  # extend conversation with function response
+            print(f"messages: {messages}")
+            second_response = litellm.completion(
+                model="gpt-3.5-turbo-1106",
+                messages=messages,
+                temperature=0.2,
+                seed=22
+            )  # get a new response from the model where it can see the function response
+            print("second response\n", second_response)
+            return second_response
 
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 95fa27439e..c63aae07b4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5247,11 +5247,14 @@ class CustomStreamWrapper:
                     original_chunk = response_obj.get("original_chunk", None)
                     model_response.id = original_chunk.id
                     if len(original_chunk.choices) > 0:
-                        try:
-                            delta = dict(original_chunk.choices[0].delta)
-                            model_response.choices[0].delta = Delta(**delta)
-                        except Exception as e:
-                            model_response.choices[0].delta = Delta()
+                        if original_chunk.choices[0].delta.function_call is not None or original_chunk.choices[0].delta.tool_calls is not None:
+                            try:
+                                delta = dict(original_chunk.choices[0].delta)
+                                model_response.choices[0].delta = Delta(**delta)
+                            except Exception as e:
+                                model_response.choices[0].delta = Delta()
+                        else:
+                            return
                     else:
                         return
                     model_response.system_fingerprint = original_chunk.system_fingerprint
@@ -5284,7 +5287,7 @@ class CustomStreamWrapper:
                 chunk = self.completion_stream
             else:
                 chunk = next(self.completion_stream)
-            
+
             if chunk is not None and chunk != b'':
                 response = self.chunk_creator(chunk=chunk)
                 if response is not None:
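
Note (not part of the patch): the stream_chunk_builder change above amounts to grouping streamed tool-call deltas by their index and concatenating each call's argument fragments, instead of assuming a single tool call per response. Below is a minimal, self-contained sketch of that grouping idea under simplified assumptions; it works on plain dicts rather than litellm's chunk objects, and the helper name group_tool_call_deltas is hypothetical.

    # Standalone sketch of the per-index grouping this patch implements in
    # stream_chunk_builder. Dict-shaped deltas and the helper name are
    # illustrative assumptions, not litellm's actual API.
    def group_tool_call_deltas(deltas):
        calls = {}  # tool-call index -> accumulated id/name/arguments (dicts keep insertion order in 3.7+)
        for delta in deltas:
            idx = delta.get("index", 0)
            call = calls.setdefault(idx, {"id": None, "name": None, "arguments": ""})
            if delta.get("id"):
                call["id"] = delta["id"]
            fn = delta.get("function", {})
            if fn.get("name"):
                call["name"] = fn["name"]
            if fn.get("arguments"):
                call["arguments"] += fn["arguments"]  # argument fragments arrive in order
        return [
            {"id": c["id"], "type": "function",
             "function": {"name": c["name"], "arguments": c["arguments"]}}
            for c in calls.values()
        ]

    # Two parallel calls whose argument fragments arrive across several chunks
    deltas = [
        {"index": 0, "id": "call_A", "function": {"name": "get_current_weather", "arguments": '{"loc'}},
        {"index": 1, "id": "call_B", "function": {"name": "get_current_weather", "arguments": '{"location": "Tokyo"}'}},
        {"index": 0, "function": {"arguments": 'ation": "San Francisco"}'}},
    ]
    assert len(group_tool_call_deltas(deltas)) == 2

Unlike the running prev_index/curr_index state kept in the patch, this sketch keys on the index directly, which also tolerates fragments that interleave across tool calls.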