fix(utils.py): fix token_counter to handle empty tool calls in messages

Fixes https://github.com/BerriAI/litellm/pull/4749
This commit is contained in:
Krrish Dholakia 2024-07-19 19:39:00 -07:00
parent b838ff22d5
commit 36ed00ec77
2 changed files with 23 additions and 7 deletions

View file

@ -20,7 +20,12 @@ from litellm import (
token_counter,
)
from litellm.tests.large_text import text
from litellm.tests.messages_with_counts import MESSAGES_TEXT, MESSAGES_WITH_IMAGES, MESSAGES_WITH_TOOLS
from litellm.tests.messages_with_counts import (
MESSAGES_TEXT,
MESSAGES_WITH_IMAGES,
MESSAGES_WITH_TOOLS,
)
def test_token_counter_normal_plus_function_calling():
try:
@ -55,27 +60,28 @@ def test_token_counter_normal_plus_function_calling():
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# test_token_counter_normal_plus_function_calling()
@pytest.mark.parametrize(
"message_count_pair",
MESSAGES_TEXT,
)
def test_token_counter_textonly(message_count_pair):
counted_tokens = token_counter(
model="gpt-35-turbo",
messages=[message_count_pair["message"]]
model="gpt-35-turbo", messages=[message_count_pair["message"]]
)
assert counted_tokens == message_count_pair["count"]
@pytest.mark.parametrize(
"message_count_pair",
MESSAGES_WITH_IMAGES,
)
def test_token_counter_with_images(message_count_pair):
counted_tokens = token_counter(
model="gpt-4o",
messages=[message_count_pair["message"]]
model="gpt-4o", messages=[message_count_pair["message"]]
)
assert counted_tokens == message_count_pair["count"]
@ -327,3 +333,13 @@ def test_get_modified_max_tokens(
), "Got={}, Expected={}, Params={}".format(
calculated_value, expected_value, args
)
def test_empty_tools():
messages = [{"role": "user", "content": "hey, how's it going?", "tool_calls": None}]
result = token_counter(
messages=messages,
)
print(result)

View file

@ -1911,7 +1911,7 @@ def token_counter(
# use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
is_tool_call = False
num_tokens = 0
if text == None:
if text is None:
if messages is not None:
print_verbose(f"token_counter messages received: {messages}")
text = ""
@ -1937,7 +1937,7 @@ def token_counter(
num_tokens += calculage_img_tokens(
data=image_url_str, mode="auto"
)
if "tool_calls" in message:
if message.get("tool_calls"):
is_tool_call = True
for tool_call in message["tool_calls"]:
if "function" in tool_call: