fix(utils.py): fix token_counter to handle empty tool calls in messages

Fixes https://github.com/BerriAI/litellm/pull/4749
This commit is contained in:
Krrish Dholakia 2024-07-19 19:39:00 -07:00
parent e2d275f1b7
commit 95a0f6839f
2 changed files with 23 additions and 7 deletions

View file

@@ -20,7 +20,12 @@ from litellm import ( @@ -20,7 +20,12 @@ from litellm import (
token_counter, token_counter,
) )
from litellm.tests.large_text import text from litellm.tests.large_text import text
from litellm.tests.messages_with_counts import MESSAGES_TEXT, MESSAGES_WITH_IMAGES, MESSAGES_WITH_TOOLS from litellm.tests.messages_with_counts import (
MESSAGES_TEXT,
MESSAGES_WITH_IMAGES,
MESSAGES_WITH_TOOLS,
)
def test_token_counter_normal_plus_function_calling(): def test_token_counter_normal_plus_function_calling():
try: try:
@@ -55,27 +60,28 @@ def test_token_counter_normal_plus_function_calling():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
# test_token_counter_normal_plus_function_calling() # test_token_counter_normal_plus_function_calling()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"message_count_pair", "message_count_pair",
MESSAGES_TEXT, MESSAGES_TEXT,
) )
def test_token_counter_textonly(message_count_pair): def test_token_counter_textonly(message_count_pair):
counted_tokens = token_counter( counted_tokens = token_counter(
model="gpt-35-turbo", model="gpt-35-turbo", messages=[message_count_pair["message"]]
messages=[message_count_pair["message"]]
) )
assert counted_tokens == message_count_pair["count"] assert counted_tokens == message_count_pair["count"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"message_count_pair", "message_count_pair",
MESSAGES_WITH_IMAGES, MESSAGES_WITH_IMAGES,
) )
def test_token_counter_with_images(message_count_pair): def test_token_counter_with_images(message_count_pair):
counted_tokens = token_counter( counted_tokens = token_counter(
model="gpt-4o", model="gpt-4o", messages=[message_count_pair["message"]]
messages=[message_count_pair["message"]]
) )
assert counted_tokens == message_count_pair["count"] assert counted_tokens == message_count_pair["count"]
@@ -327,3 +333,13 @@ def test_get_modified_max_tokens(
), "Got={}, Expected={}, Params={}".format( ), "Got={}, Expected={}, Params={}".format(
calculated_value, expected_value, args calculated_value, expected_value, args
) )
def test_empty_tools():
messages = [{"role": "user", "content": "hey, how's it going?", "tool_calls": None}]
result = token_counter(
messages=messages,
)
print(result)

View file

@@ -1911,7 +1911,7 @@ def token_counter(
# use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model # use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
is_tool_call = False is_tool_call = False
num_tokens = 0 num_tokens = 0
if text == None: if text is None:
if messages is not None: if messages is not None:
print_verbose(f"token_counter messages received: {messages}") print_verbose(f"token_counter messages received: {messages}")
text = "" text = ""
@@ -1937,7 +1937,7 @@
num_tokens += calculage_img_tokens( num_tokens += calculage_img_tokens(
data=image_url_str, mode="auto" data=image_url_str, mode="auto"
) )
if "tool_calls" in message: if message.get("tool_calls"):
is_tool_call = True is_tool_call = True
for tool_call in message["tool_calls"]: for tool_call in message["tool_calls"]:
if "function" in tool_call: if "function" in tool_call: