From 36ed00ec77c566628a053fc00aa137ea37df07e7 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 19 Jul 2024 19:39:00 -0700
Subject: [PATCH] fix(utils.py): fix token_counter to handle empty tool calls in messages

Fixes https://github.com/BerriAI/litellm/pull/4749
---
 litellm/tests/test_token_counter.py | 26 +++++++++++++++++++++-----
 litellm/utils.py                    |  4 ++--
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/litellm/tests/test_token_counter.py b/litellm/tests/test_token_counter.py
index 59d908afea..6bd001fcc8 100644
--- a/litellm/tests/test_token_counter.py
+++ b/litellm/tests/test_token_counter.py
@@ -20,7 +20,12 @@ from litellm import (
     token_counter,
 )
 from litellm.tests.large_text import text
-from litellm.tests.messages_with_counts import MESSAGES_TEXT, MESSAGES_WITH_IMAGES, MESSAGES_WITH_TOOLS
+from litellm.tests.messages_with_counts import (
+    MESSAGES_TEXT,
+    MESSAGES_WITH_IMAGES,
+    MESSAGES_WITH_TOOLS,
+)
+
 
 def test_token_counter_normal_plus_function_calling():
     try:
@@ -55,27 +60,28 @@ def test_token_counter_normal_plus_function_calling():
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
 
+
 # test_token_counter_normal_plus_function_calling()
 
+
 @pytest.mark.parametrize(
     "message_count_pair",
     MESSAGES_TEXT,
 )
 def test_token_counter_textonly(message_count_pair):
     counted_tokens = token_counter(
-        model="gpt-35-turbo",
-        messages=[message_count_pair["message"]]
+        model="gpt-35-turbo", messages=[message_count_pair["message"]]
     )
     assert counted_tokens == message_count_pair["count"]
 
+
 @pytest.mark.parametrize(
     "message_count_pair",
     MESSAGES_WITH_IMAGES,
 )
 def test_token_counter_with_images(message_count_pair):
     counted_tokens = token_counter(
-        model="gpt-4o",
-        messages=[message_count_pair["message"]]
+        model="gpt-4o", messages=[message_count_pair["message"]]
     )
     assert counted_tokens == message_count_pair["count"]
 
@@ -327,3 +333,13 @@ def test_get_modified_max_tokens(
     ), "Got={}, Expected={}, Params={}".format(
         calculated_value, expected_value, args
     )
+
+
+def test_empty_tools():
+    messages = [{"role": "user", "content": "hey, how's it going?", "tool_calls": None}]
+
+    result = token_counter(
+        messages=messages,
+    )
+
+    print(result)
diff --git a/litellm/utils.py b/litellm/utils.py
index 809613a091..f57317a3fd 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1911,7 +1911,7 @@ def token_counter(
     # use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
     is_tool_call = False
     num_tokens = 0
-    if text == None:
+    if text is None:
        if messages is not None:
             print_verbose(f"token_counter messages received: {messages}")
             text = ""
@@ -1937,7 +1937,7 @@ def token_counter(
                                     num_tokens += calculage_img_tokens(
                                         data=image_url_str, mode="auto"
                                     )
-                if "tool_calls" in message:
+                if message.get("tool_calls"):
                     is_tool_call = True
                     for tool_call in message["tool_calls"]:
                         if "function" in tool_call:
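
Note: the sketch below is not part of the patch; it is a minimal, standalone illustration of the failure mode the utils.py change guards against, using a plain dict shaped like the message in test_empty_tools. With the old check, "tool_calls" in message is True even when the value is None, so the later iteration over message["tool_calls"] raises TypeError; message.get("tool_calls") is falsy for None (and for an empty list), so the tool-call branch is skipped.

# Illustrative only: mirrors the message shape used in test_empty_tools above.
message = {"role": "user", "content": "hey, how's it going?", "tool_calls": None}

# Old check: membership is True even though the value is None, so the
# subsequent loop over message["tool_calls"] raises
# "TypeError: 'NoneType' object is not iterable".
if "tool_calls" in message:
    try:
        for tool_call in message["tool_calls"]:
            pass
    except TypeError as e:
        print(f"old check fails: {e}")

# New check: dict.get returns None here, which is falsy, so the
# tool-call branch is never entered.
if message.get("tool_calls"):
    raise AssertionError("unreachable when tool_calls is None")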