fix(utils.py): handle token counter error when invalid message passed in (#8670)

* fix(utils.py): handle token counter error

* fix(utils.py): testing fixes

* fix(utils.py): fix incr for num tokens from list

* fix(utils.py): fix text str token counting
This commit is contained in:
Krish Dholakia 2025-02-19 22:21:34 -08:00 committed by GitHub
parent 982ee4b96b
commit f9df01fbc6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 121 additions and 47 deletions

View file

@@ -1553,6 +1553,7 @@ def openai_token_counter( # noqa: PLR0915
bool
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
use_default_image_token_count: Optional[bool] = False,
default_token_count: Optional[int] = None,
):
"""
Return the number of tokens used by a list of messages.
@@ -1600,31 +1601,12 @@
if key == "name":
num_tokens += tokens_per_name
elif isinstance(value, List):
for c in value:
if c["type"] == "text":
text += c["text"]
num_tokens += len(
encoding.encode(c["text"], disallowed_special=())
)
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
text, num_tokens_from_list = _get_num_tokens_from_content_list(
content_list=value,
use_default_image_token_count=use_default_image_token_count,
default_token_count=default_token_count,
)
num_tokens += num_tokens_from_list
elif text is not None and count_response_tokens is True:
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
num_tokens = len(encoding.encode(text, disallowed_special=()))
@@ -1759,6 +1741,52 @@ def _format_type(props, indent):
return "any"
def _get_num_tokens_from_content_list(
    content_list: List[Dict[str, Any]],
    use_default_image_token_count: Optional[bool] = False,
    default_token_count: Optional[int] = None,
) -> Tuple[str, int]:
    """
    Count tokens for an OpenAI-style multimodal content list.

    Args:
        content_list: Content blocks of the form
            ``[{"type": "text", "text": ...}, {"type": "image_url", "image_url": ...}]``.
            An ``image_url`` payload may be a dict (``{"url": ..., "detail": ...}``)
            or a plain URL string.
        use_default_image_token_count: Forwarded to ``calculate_img_tokens``;
            ``None`` is coerced to ``False``.
        default_token_count: Fallback token count returned (with empty text)
            when the content list is malformed. If ``None``, the error is
            re-raised as ``ValueError`` instead.

    Returns:
        Tuple[str, int]: Concatenated text from all text blocks, and the
        total token count across text and image blocks.

    Raises:
        ValueError: A content block is malformed (e.g. missing keys) and no
            ``default_token_count`` fallback was provided.
    """
    num_tokens = 0
    text = ""
    # The whole loop is wrapped so that ANY malformed block triggers the
    # default_token_count fallback rather than a partial count.
    try:
        for c in content_list:
            if c["type"] == "text":
                text += c["text"]
                num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
            elif c["type"] == "image_url":
                image_url = c["image_url"]
                if isinstance(image_url, dict):
                    num_tokens += calculate_img_tokens(
                        data=image_url.get("url"),
                        mode=image_url.get("detail", "auto"),
                        use_default_image_token_count=use_default_image_token_count
                        or False,
                    )
                elif isinstance(image_url, str):
                    num_tokens += calculate_img_tokens(
                        data=image_url,
                        mode="auto",
                        use_default_image_token_count=use_default_image_token_count
                        or False,
                    )
        return text, num_tokens
    except Exception as e:
        if default_token_count is not None:
            return "", default_token_count
        # Chain the original exception so the true cause stays in the traceback.
        raise ValueError(
            f"Error getting number of tokens from content list: {e}, default_token_count={default_token_count}"
        ) from e
def token_counter(
model="",
custom_tokenizer: Optional[Union[dict, SelectTokenizerResponse]] = None,
@@ -1768,6 +1796,7 @@ def token_counter(
tools: Optional[List[ChatCompletionToolParam]] = None,
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
use_default_image_token_count: Optional[bool] = False,
default_token_count: Optional[int] = None,
) -> int:
"""
Count the number of tokens in a given text using a specified model.
@@ -1777,6 +1806,7 @@
custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None.
text (str): The raw text string to be passed to the model. Default is None.
messages (Optional[List[Dict[str, str]]]): Alternative to passing in text. A list of dictionaries representing messages with "role" and "content" keys. Default is None.
default_token_count (Optional[int]): The default number of tokens to return for a message block, if an error occurs. Default is None.
Returns:
int: The number of tokens in the text.
@@ -1794,28 +1824,11 @@ def token_counter(
if isinstance(content, str):
text += message["content"]
elif isinstance(content, List):
for c in content:
if c["type"] == "text":
text += c["text"]
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
text, num_tokens = _get_num_tokens_from_content_list(
content_list=content,
use_default_image_token_count=use_default_image_token_count,
default_token_count=default_token_count,
)
if message.get("tool_calls"):
is_tool_call = True
for tool_call in message["tool_calls"]:
@@ -1859,6 +1872,7 @@ def token_counter(
tool_choice=tool_choice,
use_default_image_token_count=use_default_image_token_count
or False,
default_token_count=default_token_count,
)
else:
print_verbose(
@@ -1874,6 +1888,7 @@ def token_counter(
tool_choice=tool_choice,
use_default_image_token_count=use_default_image_token_count
or False,
default_token_count=default_token_count,
)
else:
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore

View file

@@ -470,3 +470,62 @@ class TestTokenizerSelection(unittest.TestCase):
mock_return_huggingface_tokenizer.assert_not_called()
assert result["type"] == "openai_tokenizer"
assert result["tokenizer"] == encoding
@pytest.mark.parametrize(
    "model",
    [
        "gpt-4o",
        "claude-3-opus-20240229",
    ],
)
@pytest.mark.parametrize(
    "messages",
    [
        # Malformed block: declares type "text" but carries an image_url DICT payload.
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "These are some sample images from a movie. Based on these images, what do you think the tone of the movie is?",
                    },
                    {
                        "type": "text",
                        "image_url": {
                            "url": "https://gratisography.com/wp-content/uploads/2024/11/gratisography-augmented-reality-800x525.jpg",
                            "detail": "high",
                        },
                    },
                ],
            }
        ],
        # Malformed block: declares type "text" but carries an image_url STRING payload.
        # (Previously this case was an exact duplicate of the one above.)
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "These are some sample images from a movie. Based on these images, what do you think the tone of the movie is?",
                    },
                    {
                        "type": "text",
                        "image_url": "https://gratisography.com/wp-content/uploads/2024/11/gratisography-augmented-reality-800x525.jpg",
                    },
                ],
            }
        ],
    ],
)
def test_bad_input_token_counter(model, messages):
    """
    Safely handle bad input for token counter.

    Malformed content blocks must not raise when a default_token_count
    fallback is supplied; token_counter should return that fallback instead.
    """
    token_counter(
        model=model,
        messages=messages,
        default_token_count=1000,
    )