forked from phoenix/litellm-mirror
fix(utils.py): fix token_counter to handle empty tool calls in messages
Fixes https://github.com/BerriAI/litellm/pull/4749
This commit is contained in:
parent
e2d275f1b7
commit
95a0f6839f
2 changed files with 23 additions and 7 deletions
|
@ -20,7 +20,12 @@ from litellm import (
|
||||||
token_counter,
|
token_counter,
|
||||||
)
|
)
|
||||||
from litellm.tests.large_text import text
|
from litellm.tests.large_text import text
|
||||||
from litellm.tests.messages_with_counts import MESSAGES_TEXT, MESSAGES_WITH_IMAGES, MESSAGES_WITH_TOOLS
|
from litellm.tests.messages_with_counts import (
|
||||||
|
MESSAGES_TEXT,
|
||||||
|
MESSAGES_WITH_IMAGES,
|
||||||
|
MESSAGES_WITH_TOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_token_counter_normal_plus_function_calling():
|
def test_token_counter_normal_plus_function_calling():
|
||||||
try:
|
try:
|
||||||
|
@ -55,27 +60,28 @@ def test_token_counter_normal_plus_function_calling():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# test_token_counter_normal_plus_function_calling()
|
# test_token_counter_normal_plus_function_calling()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"message_count_pair",
|
"message_count_pair",
|
||||||
MESSAGES_TEXT,
|
MESSAGES_TEXT,
|
||||||
)
|
)
|
||||||
def test_token_counter_textonly(message_count_pair):
|
def test_token_counter_textonly(message_count_pair):
|
||||||
counted_tokens = token_counter(
|
counted_tokens = token_counter(
|
||||||
model="gpt-35-turbo",
|
model="gpt-35-turbo", messages=[message_count_pair["message"]]
|
||||||
messages=[message_count_pair["message"]]
|
|
||||||
)
|
)
|
||||||
assert counted_tokens == message_count_pair["count"]
|
assert counted_tokens == message_count_pair["count"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"message_count_pair",
|
"message_count_pair",
|
||||||
MESSAGES_WITH_IMAGES,
|
MESSAGES_WITH_IMAGES,
|
||||||
)
|
)
|
||||||
def test_token_counter_with_images(message_count_pair):
|
def test_token_counter_with_images(message_count_pair):
|
||||||
counted_tokens = token_counter(
|
counted_tokens = token_counter(
|
||||||
model="gpt-4o",
|
model="gpt-4o", messages=[message_count_pair["message"]]
|
||||||
messages=[message_count_pair["message"]]
|
|
||||||
)
|
)
|
||||||
assert counted_tokens == message_count_pair["count"]
|
assert counted_tokens == message_count_pair["count"]
|
||||||
|
|
||||||
|
@ -327,3 +333,13 @@ def test_get_modified_max_tokens(
|
||||||
), "Got={}, Expected={}, Params={}".format(
|
), "Got={}, Expected={}, Params={}".format(
|
||||||
calculated_value, expected_value, args
|
calculated_value, expected_value, args
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_tools():
|
||||||
|
messages = [{"role": "user", "content": "hey, how's it going?", "tool_calls": None}]
|
||||||
|
|
||||||
|
result = token_counter(
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(result)
|
||||||
|
|
|
@ -1911,7 +1911,7 @@ def token_counter(
|
||||||
# use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
|
# use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
|
||||||
is_tool_call = False
|
is_tool_call = False
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
if text == None:
|
if text is None:
|
||||||
if messages is not None:
|
if messages is not None:
|
||||||
print_verbose(f"token_counter messages received: {messages}")
|
print_verbose(f"token_counter messages received: {messages}")
|
||||||
text = ""
|
text = ""
|
||||||
|
@ -1937,7 +1937,7 @@ def token_counter(
|
||||||
num_tokens += calculage_img_tokens(
|
num_tokens += calculage_img_tokens(
|
||||||
data=image_url_str, mode="auto"
|
data=image_url_str, mode="auto"
|
||||||
)
|
)
|
||||||
if "tool_calls" in message:
|
if message.get("tool_calls"):
|
||||||
is_tool_call = True
|
is_tool_call = True
|
||||||
for tool_call in message["tool_calls"]:
|
for tool_call in message["tool_calls"]:
|
||||||
if "function" in tool_call:
|
if "function" in tool_call:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue