mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(utils.py): handle token counter error when invalid message passed in (#8670)
* fix(utils.py): handle token counter error
* fix(utils.py): testing fixes
* fix(utils.py): fix incr for num tokens from list
* fix(utils.py): fix text str token counting
This commit is contained in:
parent
982ee4b96b
commit
f9df01fbc6
2 changed files with 121 additions and 47 deletions
121
litellm/utils.py
121
litellm/utils.py
|
@ -1553,6 +1553,7 @@ def openai_token_counter( # noqa: PLR0915
|
||||||
bool
|
bool
|
||||||
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
|
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
|
||||||
use_default_image_token_count: Optional[bool] = False,
|
use_default_image_token_count: Optional[bool] = False,
|
||||||
|
default_token_count: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Return the number of tokens used by a list of messages.
|
Return the number of tokens used by a list of messages.
|
||||||
|
@ -1600,31 +1601,12 @@ def openai_token_counter( # noqa: PLR0915
|
||||||
if key == "name":
|
if key == "name":
|
||||||
num_tokens += tokens_per_name
|
num_tokens += tokens_per_name
|
||||||
elif isinstance(value, List):
|
elif isinstance(value, List):
|
||||||
for c in value:
|
text, num_tokens_from_list = _get_num_tokens_from_content_list(
|
||||||
if c["type"] == "text":
|
content_list=value,
|
||||||
text += c["text"]
|
use_default_image_token_count=use_default_image_token_count,
|
||||||
num_tokens += len(
|
default_token_count=default_token_count,
|
||||||
encoding.encode(c["text"], disallowed_special=())
|
|
||||||
)
|
|
||||||
elif c["type"] == "image_url":
|
|
||||||
if isinstance(c["image_url"], dict):
|
|
||||||
image_url_dict = c["image_url"]
|
|
||||||
detail = image_url_dict.get("detail", "auto")
|
|
||||||
url = image_url_dict.get("url")
|
|
||||||
num_tokens += calculate_img_tokens(
|
|
||||||
data=url,
|
|
||||||
mode=detail,
|
|
||||||
use_default_image_token_count=use_default_image_token_count
|
|
||||||
or False,
|
|
||||||
)
|
|
||||||
elif isinstance(c["image_url"], str):
|
|
||||||
image_url_str = c["image_url"]
|
|
||||||
num_tokens += calculate_img_tokens(
|
|
||||||
data=image_url_str,
|
|
||||||
mode="auto",
|
|
||||||
use_default_image_token_count=use_default_image_token_count
|
|
||||||
or False,
|
|
||||||
)
|
)
|
||||||
|
num_tokens += num_tokens_from_list
|
||||||
elif text is not None and count_response_tokens is True:
|
elif text is not None and count_response_tokens is True:
|
||||||
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
|
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
|
||||||
num_tokens = len(encoding.encode(text, disallowed_special=()))
|
num_tokens = len(encoding.encode(text, disallowed_special=()))
|
||||||
|
@ -1759,44 +1741,24 @@ def _format_type(props, indent):
|
||||||
return "any"
|
return "any"
|
||||||
|
|
||||||
|
|
||||||
def token_counter(
|
def _get_num_tokens_from_content_list(
|
||||||
model="",
|
content_list: List[Dict[str, Any]],
|
||||||
custom_tokenizer: Optional[Union[dict, SelectTokenizerResponse]] = None,
|
|
||||||
text: Optional[Union[str, List[str]]] = None,
|
|
||||||
messages: Optional[List] = None,
|
|
||||||
count_response_tokens: Optional[bool] = False,
|
|
||||||
tools: Optional[List[ChatCompletionToolParam]] = None,
|
|
||||||
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
|
|
||||||
use_default_image_token_count: Optional[bool] = False,
|
use_default_image_token_count: Optional[bool] = False,
|
||||||
) -> int:
|
default_token_count: Optional[int] = None,
|
||||||
|
) -> Tuple[str, int]:
|
||||||
"""
|
"""
|
||||||
Count the number of tokens in a given text using a specified model.
|
Get the number of tokens from a list of content.
|
||||||
|
|
||||||
Args:
|
|
||||||
model (str): The name of the model to use for tokenization. Default is an empty string.
|
|
||||||
custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None.
|
|
||||||
text (str): The raw text string to be passed to the model. Default is None.
|
|
||||||
messages (Optional[List[Dict[str, str]]]): Alternative to passing in text. A list of dictionaries representing messages with "role" and "content" keys. Default is None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: The number of tokens in the text.
|
Tuple[str, int]: A tuple containing the text and the number of tokens.
|
||||||
"""
|
"""
|
||||||
# use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
|
try:
|
||||||
is_tool_call = False
|
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
if text is None:
|
|
||||||
if messages is not None:
|
|
||||||
print_verbose(f"token_counter messages received: {messages}")
|
|
||||||
text = ""
|
text = ""
|
||||||
for message in messages:
|
for c in content_list:
|
||||||
if message.get("content", None) is not None:
|
|
||||||
content = message.get("content")
|
|
||||||
if isinstance(content, str):
|
|
||||||
text += message["content"]
|
|
||||||
elif isinstance(content, List):
|
|
||||||
for c in content:
|
|
||||||
if c["type"] == "text":
|
if c["type"] == "text":
|
||||||
text += c["text"]
|
text += c["text"]
|
||||||
|
num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
|
||||||
elif c["type"] == "image_url":
|
elif c["type"] == "image_url":
|
||||||
if isinstance(c["image_url"], dict):
|
if isinstance(c["image_url"], dict):
|
||||||
image_url_dict = c["image_url"]
|
image_url_dict = c["image_url"]
|
||||||
|
@ -1816,6 +1778,57 @@ def token_counter(
|
||||||
use_default_image_token_count=use_default_image_token_count
|
use_default_image_token_count=use_default_image_token_count
|
||||||
or False,
|
or False,
|
||||||
)
|
)
|
||||||
|
return text, num_tokens
|
||||||
|
except Exception as e:
|
||||||
|
if default_token_count is not None:
|
||||||
|
return "", default_token_count
|
||||||
|
raise ValueError(
|
||||||
|
f"Error getting number of tokens from content list: {e}, default_token_count={default_token_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def token_counter(
|
||||||
|
model="",
|
||||||
|
custom_tokenizer: Optional[Union[dict, SelectTokenizerResponse]] = None,
|
||||||
|
text: Optional[Union[str, List[str]]] = None,
|
||||||
|
messages: Optional[List] = None,
|
||||||
|
count_response_tokens: Optional[bool] = False,
|
||||||
|
tools: Optional[List[ChatCompletionToolParam]] = None,
|
||||||
|
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
|
||||||
|
use_default_image_token_count: Optional[bool] = False,
|
||||||
|
default_token_count: Optional[int] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count the number of tokens in a given text using a specified model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The name of the model to use for tokenization. Default is an empty string.
|
||||||
|
custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None.
|
||||||
|
text (str): The raw text string to be passed to the model. Default is None.
|
||||||
|
messages (Optional[List[Dict[str, str]]]): Alternative to passing in text. A list of dictionaries representing messages with "role" and "content" keys. Default is None.
|
||||||
|
default_token_count (Optional[int]): The default number of tokens to return for a message block, if an error occurs. Default is None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The number of tokens in the text.
|
||||||
|
"""
|
||||||
|
# use tiktoken, anthropic, cohere, llama2, or llama3's tokenizer depending on the model
|
||||||
|
is_tool_call = False
|
||||||
|
num_tokens = 0
|
||||||
|
if text is None:
|
||||||
|
if messages is not None:
|
||||||
|
print_verbose(f"token_counter messages received: {messages}")
|
||||||
|
text = ""
|
||||||
|
for message in messages:
|
||||||
|
if message.get("content", None) is not None:
|
||||||
|
content = message.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
text += message["content"]
|
||||||
|
elif isinstance(content, List):
|
||||||
|
text, num_tokens = _get_num_tokens_from_content_list(
|
||||||
|
content_list=content,
|
||||||
|
use_default_image_token_count=use_default_image_token_count,
|
||||||
|
default_token_count=default_token_count,
|
||||||
|
)
|
||||||
if message.get("tool_calls"):
|
if message.get("tool_calls"):
|
||||||
is_tool_call = True
|
is_tool_call = True
|
||||||
for tool_call in message["tool_calls"]:
|
for tool_call in message["tool_calls"]:
|
||||||
|
@ -1859,6 +1872,7 @@ def token_counter(
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
use_default_image_token_count=use_default_image_token_count
|
use_default_image_token_count=use_default_image_token_count
|
||||||
or False,
|
or False,
|
||||||
|
default_token_count=default_token_count,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
@ -1874,6 +1888,7 @@ def token_counter(
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
use_default_image_token_count=use_default_image_token_count
|
use_default_image_token_count=use_default_image_token_count
|
||||||
or False,
|
or False,
|
||||||
|
default_token_count=default_token_count,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
|
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
|
||||||
|
|
|
@ -470,3 +470,62 @@ class TestTokenizerSelection(unittest.TestCase):
|
||||||
mock_return_huggingface_tokenizer.assert_not_called()
|
mock_return_huggingface_tokenizer.assert_not_called()
|
||||||
assert result["type"] == "openai_tokenizer"
|
assert result["type"] == "openai_tokenizer"
|
||||||
assert result["tokenizer"] == encoding
|
assert result["tokenizer"] == encoding
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"gpt-4o",
|
||||||
|
"claude-3-opus-20240229",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"messages",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "These are some sample images from a movie. Based on these images, what do you think the tone of the movie is?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://gratisography.com/wp-content/uploads/2024/11/gratisography-augmented-reality-800x525.jpg",
|
||||||
|
"detail": "high",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "These are some sample images from a movie. Based on these images, what do you think the tone of the movie is?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://gratisography.com/wp-content/uploads/2024/11/gratisography-augmented-reality-800x525.jpg",
|
||||||
|
"detail": "high",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bad_input_token_counter(model, messages):
|
||||||
|
"""
|
||||||
|
Safely handle bad input for token counter.
|
||||||
|
"""
|
||||||
|
token_counter(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
default_token_count=1000,
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue