From 4cdd930fa216851a5440d3539570517558e1184e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 27 Nov 2023 18:39:10 -0800
Subject: [PATCH] fix(stream_chunk_builder): adding support for tool calling
 in completion counting

---
 litellm/main.py                            | 43 +++++++++++--
 litellm/tests/test_stream_chunk_builder.py | 74 ++++++++++++++++++++--
 litellm/utils.py                           |  4 ++
 3 files changed, 111 insertions(+), 10 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 64fe498d97..d4a3826bb6 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2109,7 +2109,37 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
     content_list = []
     combined_content = ""
+    combined_arguments = ""  # always defined, so the usage calculation below can't hit a NameError
 
-    if "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None:
+    if "tool_calls" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None:
+        argument_list = []
+        delta = chunks[0]["choices"][0]["delta"]
+        message = response["choices"][0]["message"]
+        message["tool_calls"] = []
+        id = None
+        name = None
+        type = None
+        for chunk in chunks:
+            choices = chunk["choices"]
+            for choice in choices:
+                delta = choice.get("delta", {})
+                tool_calls = delta.get("tool_calls", "")
+                # Check if a tool call is present
+                if tool_calls and tool_calls[0].function is not None:
+                    if tool_calls[0].id:
+                        id = tool_calls[0].id
+                    if tool_calls[0].function.arguments:
+                        # tool_calls is a list of tool-call objects; collect the streamed argument fragments
+                        arguments = tool_calls[0].function.arguments
+                        argument_list.append(arguments)
+                    if tool_calls[0].function.name:
+                        name = tool_calls[0].function.name
+                    if tool_calls[0].type:
+                        type = tool_calls[0].type
+
+        combined_arguments = "".join(argument_list)
+        response["choices"][0]["message"]["content"] = None
+        response["choices"][0]["message"]["tool_calls"] = [{"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type}]
+    elif "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None:
         argument_list = []
         delta = chunks[0]["choices"][0]["delta"]
         function_call = delta.get("function_call", "")
@@ -2144,16 +2174,21 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
             continue # openai v1.0.0 sets content = None for chunks
         content_list.append(content)
 
-    # Combine the "content" strings into a single string
-    combined_content = "".join(content_list)
+    # Combine the "content" strings into a single string || combine the 'function' strings into a single string
+    combined_content = "".join(content_list)
 
     # Update the "content" field within the response dictionary
     response["choices"][0]["message"]["content"] = combined_content
 
+    completion_output = ""  # fall back to empty output if neither content nor arguments were streamed
+    if len(combined_content) > 0:
+        completion_output = combined_content
+    elif len(combined_arguments) > 0:
+        completion_output = combined_arguments
     # # Update usage information if needed
     if messages:
         response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
-        response["usage"]["completion_tokens"] = token_counter(model=model, text=combined_content)
+        response["usage"]["completion_tokens"] = token_counter(model=model, text=completion_output)
         response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
 
     return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
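For context on the main.py hunk above: the new `tool_calls` branch rebuilds a single tool call from streamed deltas by keeping the last-seen id/name/type and concatenating the `function.arguments` fragments in arrival order, and the joined string is what gets token-counted as completion output. Below is a minimal, self-contained sketch of that accumulation pattern; `FakeToolCall`/`FakeFunction`/`rebuild_tool_call` are illustrative stand-ins, not LiteLLM types.

    from dataclasses import dataclass
    from typing import Optional

    # Illustrative stand-ins for streamed delta objects; not LiteLLM's real classes.
    @dataclass
    class FakeFunction:
        name: Optional[str] = None
        arguments: Optional[str] = None

    @dataclass
    class FakeToolCall:
        id: Optional[str] = None
        type: Optional[str] = None
        function: Optional[FakeFunction] = None

    def rebuild_tool_call(deltas):
        """Merge streamed tool-call deltas: keep the last-seen id/name/type,
        concatenate the argument fragments in arrival order."""
        tool_id = name = call_type = None
        argument_list = []
        for tc in deltas:
            if tc.id:
                tool_id = tc.id
            if tc.type:
                call_type = tc.type
            if tc.function is not None:
                if tc.function.name:
                    name = tc.function.name
                if tc.function.arguments:
                    argument_list.append(tc.function.arguments)
        return {"id": tool_id, "type": call_type,
                "function": {"name": name, "arguments": "".join(argument_list)}}

    # First delta carries the metadata; later deltas stream the JSON arguments.
    deltas = [
        FakeToolCall(id="call_123", type="function",
                     function=FakeFunction(name="get_current_weather", arguments="")),
        FakeToolCall(function=FakeFunction(arguments='{"location": "San ')),
        FakeToolCall(function=FakeFunction(arguments='Francisco, CA"}')),
    ]
    call = rebuild_tool_call(deltas)
    assert call["function"]["arguments"] == '{"location": "San Francisco, CA"}'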
diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py
index bad6f7f735..5907d37d06 100644
--- a/litellm/tests/test_stream_chunk_builder.py
+++ b/litellm/tests/test_stream_chunk_builder.py
@@ -7,6 +7,7 @@ sys.path.insert(
 from litellm import completion, stream_chunk_builder
 import litellm
 import os, dotenv
+from openai import OpenAI
 import pytest
 
 dotenv.load_dotenv()
@@ -30,20 +31,81 @@ function_schema = {
     },
 }
 
-def test_stream_chunk_builder():
+
+tools_schema = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"]
+                    }
+                },
+                "required": ["location"]
+            }
+        }
+    }
+]
+
+# def test_stream_chunk_builder_tools():
+#     try:
+#         litellm.set_verbose = False
+#         response = client.chat.completions.create(
+#             model="gpt-3.5-turbo",
+#             messages=messages,
+#             tools=tools_schema,
+#             # stream=True,
+#             # complete_response=True # runs stream_chunk_builder under-the-hood
+#         )
+
+#         print(f"response: {response}")
+#         print(f"response usage: {response.usage}")
+#     except Exception as e:
+#         pytest.fail(f"An exception occurred - {str(e)}")
+
+# test_stream_chunk_builder_tools()
+
+def test_stream_chunk_builder_litellm_function_call():
     try:
         litellm.set_verbose = False
-        response = completion(
+        response = litellm.completion(
             model="gpt-3.5-turbo",
             messages=messages,
             functions=[function_schema],
-            stream=True,
-            complete_response=True # runs stream_chunk_builder under-the-hood
+            # stream=True,
+            # complete_response=True # runs stream_chunk_builder under-the-hood
         )
 
         print(f"response: {response}")
-        print(f"response usage: {response['usage']}")
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
 
-test_stream_chunk_builder()
\ No newline at end of file
+# test_stream_chunk_builder_litellm_function_call()
+
+def test_stream_chunk_builder_litellm_tool_call():
+    try:
+        litellm.set_verbose = False
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=messages,
+            tools=tools_schema,
+            stream=True,
+            complete_response=True
+        )
+
+        print(f"complete response: {response}")
+        print(f"complete response usage: {response.usage}")
+
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+
+test_stream_chunk_builder_litellm_tool_call()
diff --git a/litellm/utils.py b/litellm/utils.py
index 323b71a707..eb107692c6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -140,12 +140,16 @@ class Message(OpenAIObject):
         self.role = role
         if function_call is not None:
             self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = None
         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
                 self.tool_calls.append(
                     ChatCompletionMessageToolCall(**tool_call)
                 )
+        else:
+            self.tool_calls = None
         if logprobs is not None:
             self._logprobs = logprobs
 
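For context on the utils.py hunk: the added `else` branches guarantee that every `Message` exposes `function_call` and `tool_calls` attributes (defaulting to `None`), so `stream_chunk_builder` and other callers can branch on them without `hasattr()` guards. A minimal sketch of the same pattern, using a plain class as a stand-in for LiteLLM's `Message` (the `SimpleMessage` name and dict-based storage are illustrative assumptions; the real class subclasses OpenAIObject and wraps tool calls in `ChatCompletionMessageToolCall`):

    from typing import List, Optional

    class SimpleMessage:
        """Stand-in for litellm's Message; illustrates the always-set-attribute pattern."""
        def __init__(self, content: Optional[str] = None,
                     function_call: Optional[dict] = None,
                     tool_calls: Optional[List[dict]] = None):
            self.content = content
            if function_call is not None:
                self.function_call = dict(function_call)
            else:
                # Default to None so callers can test `msg.function_call is None`
                # instead of guarding attribute access.
                self.function_call = None
            if tool_calls is not None:
                self.tool_calls = [dict(tc) for tc in tool_calls]
            else:
                self.tool_calls = None

    plain = SimpleMessage(content="hello")
    tooled = SimpleMessage(content=None, tool_calls=[
        {"id": "call_123", "type": "function",
         "function": {"name": "get_current_weather", "arguments": "{}"}}
    ])
    assert plain.tool_calls is None and plain.function_call is None
    assert tooled.tool_calls[0]["function"]["name"] == "get_current_weather"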