diff --git a/litellm/tests/messages_with_counts.py b/litellm/tests/messages_with_counts.py new file mode 100644 index 000000000..bad19ff01 --- /dev/null +++ b/litellm/tests/messages_with_counts.py @@ -0,0 +1,706 @@ +system_message_short = { + "message": { + "role": "system", + "content": "You are a bot.", + }, + "count": 12 +} + +system_message = { + "message": { + "role": "system", + "content": "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.", + }, + "count": 25 +} + +system_message_long = { + "message": { + "role": "system", + "content": "Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.", + }, + "count": 31 +} + +system_message_unicode = { + "message": { + "role": "system", + "content": "á", + }, + "count": 8 +} + +system_message_with_name = { + "message": { + "role": "system", + "name": "example_user", + "content": "New synergies will help drive top-line growth.", + }, + "count": 20 +} + +user_message = { + "message": { + "role": "user", + "content": "Hello, how are you?", + }, + "count": 13 +} + +user_message_unicode = { + "message": { + "role": "user", + "content": "á", + }, + "count": 8 +} + +user_message_perf = { + "message": { + "role": "user", + "content": "What happens in a performance review?", + }, + "count": 14 +} + +assistant_message_perf = { + "message": { + "role": "assistant", + "content": "During the performance review at Contoso Electronics, the supervisor will discuss the employee's performance over the past year and provide feedback on areas for improvement. They will also provide an opportunity for the employee to discuss their goals and objectives for the upcoming year. The review is a two-way dialogue between managers and employees, and employees will receive a written summary of their performance review which will include a rating of their performance, feedback, and goals and objectives for the upcoming year [employee_handbook-3.pdf].", + }, + "count": 106 +} + +assistant_message_perf_short = { + "message": { + "role": "assistant", + "content": "The supervisor will discuss the employee's performance and provide feedback on areas for improvement. They will also provide an opportunity for the employee to discuss their goals and objectives for the upcoming year. The review is a two-way dialogue between managers and employees, and employees will receive a written summary of their performance review which will include a rating of their performance, feedback, and goals for the upcoming year [employee_handbook-3.pdf].", + }, + "count": 91 +} + +user_message_dresscode = { + "message": { + "role": "user", + "content": "Is there a dress code?", + }, + "count": 13 +} + +assistant_message_dresscode = { + "message": { + "role": "assistant", + "content": "Yes, there is a dress code at Contoso Electronics. Look sharp! 
[employee_handbook-1.pdf]", + }, + "count": 30 +} + +user_message_pm = { + "message": { + "role": "user", + "content": "What does a Product Manager do?", + }, + "count": 14 +} + +text_and_image_message = { + "message": { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this picture:"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z/C/HgAGgwJ/lK3Q6wAAAABJRU5ErkJggg==", + "detail": "high", + }, + }, + ], + }, + "count": 266 +} + + + +search_sources_toolchoice_auto = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + "parameters": { + "type": "object", + "properties": { + "search_query": { + "type": "string", + "description": "Query string to retrieve documents from azure search eg: 'Health care plan'", + } + }, + "required": ["search_query"], + }, + }, + } + ], + "tool_choice": "auto", + "count": 66, +} + +search_sources_toolchoice_none = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + "parameters": { + "type": "object", + "properties": { + "search_query": { + "type": "string", + "description": "Query string to retrieve documents from azure search eg: 'Health care plan'", + } + }, + "required": ["search_query"], + }, + }, + } + ], + "tool_choice": "none", + "count": 67, +} + +search_sources_toolchoice_name = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + "parameters": { + "type": "object", + "properties": { + "search_query": { + "type": "string", + "description": "Query string to retrieve documents from azure search eg: 'Health care plan'", + } + }, + "required": ["search_query"], + }, + }, + } + ], + "tool_choice": {"type": "function", "function": {"name": "search_sources"}}, + "count": 75, +} + +integer_enum = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "data_demonstration", + "description": "This is the main function description", + "parameters": {"type": "object", "properties": {"integer_enum": {"type": "integer", "enum": [-1, 1]}}}, + }, + } + ], + "tool_choice": "none", + "count": 54, +} + + +integer_enum_tool_choice_name = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "data_demonstration", + "description": "This is the main function description", + "parameters": {"type": "object", "properties": {"integer_enum": {"type": "integer", "enum": [-1, 1]}}}, + }, + } + ], + "tool_choice": { + "type": "function", + "function": {"name": "data_demonstration"}, + }, # 4 tokens for "data_demonstration" + "count": 64, +} + +no_parameters = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + }, + } + ], + "tool_choice": "auto", + "count": 42, +} + 
+no_parameters_tool_choice_name = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + }, + } + ], + "tool_choice": {"type": "function", "function": {"name": "search_sources"}}, # 2 tokens for "search_sources" + "count": 51, +} + +no_parameter_description_or_required = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + "parameters": {"type": "object", "properties": {"search_query": {"type": "string"}}}, + }, + } + ], + "tool_choice": "auto", + "count": 49, +} + +no_parameter_description = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "search_sources", + "description": "Retrieve sources from the Azure AI Search index", + "parameters": { + "type": "object", + "properties": {"search_query": {"type": "string"}}, + "required": ["search_query"], + }, + }, + } + ], + "tool_choice": "auto", + "count": 49, +} + +string_enum = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "summarize_order", + "description": "Summarize the customer order request", + "parameters": { + "type": "object", + "properties": { + "product_name": { + "type": "string", + "description": "Product name ordered by customer", + }, + "quantity": { + "type": "integer", + "description": "Quantity ordered by customer", + }, + "unit": { + "type": "string", + "enum": ["meals", "days"], + "description": "unit of measurement of the customer order", + }, + }, + "required": ["product_name", "quantity", "unit"], + }, + }, + } + ], + "tool_choice": "none", + "count": 86, +} + +inner_object = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "data_demonstration", + "description": "This is the main function description", + "parameters": { + "type": "object", + "properties": { + "object_1": { + "type": "object", + "description": "The object data type as a property", + "properties": { + "string1": {"type": "string"}, + }, + } + }, + "required": ["object_1"], + }, + }, + } + ], + "tool_choice": "none", + "count": 65, # counted 67, over by 2 +} +""" +namespace functions { + +// This is the main function description +type data_demonstration = (_: { +// The object data type as a property +object_1: { + string1?: string, +}, +}) => any; + +} // namespace functions +""" + +inner_object_with_enum_only = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "data_demonstration", + "description": "This is the main function description", + "parameters": { + "type": "object", + "properties": { + "object_1": { + "type": "object", + "description": "The object data type as a property", + "properties": {"string_2a": {"type": "string", "enum": ["Happy", "Sad"]}}, + } + }, + "required": ["object_1"], + }, + }, + } + ], + "tool_choice": "none", + "count": 73, # counted 74, over by 1 +} +""" +namespace functions { + +// This is the main function description +type data_demonstration = (_: { +// The object data type as a property 
+object_1: { + string_2a?: "Happy" | "Sad", +}, +}) => any; + +} // namespace functions +""" + +inner_object_with_enum = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "data_demonstration", + "description": "This is the main function description", + "parameters": { + "type": "object", + "properties": { + "object_1": { + "type": "object", + "description": "The object data type as a property", + "properties": { + "string_2a": {"type": "string", "enum": ["Happy", "Sad"]}, + "string_2b": { + "type": "string", + "description": "Description in a second object is lost", + }, + }, + } + }, + "required": ["object_1"], + }, + }, + } + ], + "tool_choice": "none", + "count": 89, # counted 92, over by 3 +} +""" +namespace functions { + +// This is the main function description +type data_demonstration = (_: { +// The object data type as a property +object_1: { + string_2a?: "Happy" | "Sad", + // Description in a second object is lost + string_2b?: string, +}, +}) => any; + +} // namespace functions +""" + +inner_object_and_string = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "data_demonstration", + "description": "This is the main function description", + "parameters": { + "type": "object", + "properties": { + "object_1": { + "type": "object", + "description": "The object data type as a property", + "properties": { + "string_2a": {"type": "string", "enum": ["Happy", "Sad"]}, + "string_2b": { + "type": "string", + "description": "Description in a second object is lost", + }, + }, + }, + "string_1": {"type": "string", "description": "Not required gets a question mark"}, + }, + "required": ["object_1"], + }, + }, + } + ], + "tool_choice": "none", + "count": 103, # counted 106, over by 3 +} +""" +namespace functions { + +// This is the main function description +type data_demonstration = (_: { +// The object data type as a property +object_1: { + string_2a?: "Happy" | "Sad", + // Description in a second object is lost + string_2b?: string, +}, +// Not required gets a question mark +string_1?: string, +}) => any; + +} // namespace functions +""" + +boolean = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "human_escalation", + "description": "Check if user wants to escalate to a human", + "parameters": { + "type": "object", + "properties": { + "requires_escalation": { + "type": "boolean", + "description": "If user is showing signs of frustration or anger in the query. 
Also if the user says they want to talk to a real person and not a chat bot.", + } + }, + "required": ["requires_escalation"], + }, + }, + } + ], + "tool_choice": "none", + "count": 89, # over by 3 +} + +array = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "get_coordinates", + "description": "Get the latitude and longitude of multiple mailing addresses", + "parameters": { + "type": "object", + "properties": { + "addresses": { + "type": "array", + "description": "The mailing addresses to be located", + "items": {"type": "string"}, + } + }, + "required": ["addresses"], + }, + }, + } + ], + "tool_choice": "none", + "count": 59, +} + +null = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "get_null", + "description": "Get the null value", + "parameters": { + "type": "object", + "properties": { + "null_value": { + "type": "null", + "description": "The null value to be returned", + } + }, + "required": ["null_value"], + }, + }, + } + ], + "tool_choice": "none", + "count": 55, +} + +no_type = { + "system_message": { + "role": "system", + "content": "You are a bot.", + }, + "tools": [ + { + "type": "function", + "function": { + "name": "get_no_type", + "description": "Get the no type value", + "parameters": { + "type": "object", + "properties": { + "no_type_value": { + "description": "The no type value to be returned", + } + }, + "required": ["no_type_value"], + }, + }, + } + ], + "tool_choice": "none", + "count": 59, +} + +MESSAGES_TEXT = [ + system_message, + system_message_short, + system_message_long, + system_message_unicode, + system_message_with_name, + user_message, + user_message_unicode, + user_message_perf, + user_message_dresscode, + user_message_pm, + assistant_message_perf, + assistant_message_perf_short, + assistant_message_dresscode +] + +MESSAGES_WITH_IMAGES = [ + text_and_image_message +] + +MESSAGES_WITH_TOOLS = [ + inner_object, + inner_object_and_string, + inner_object_with_enum_only, + inner_object_with_enum, + search_sources_toolchoice_auto, + search_sources_toolchoice_none, + search_sources_toolchoice_name, + integer_enum, + integer_enum_tool_choice_name, + no_parameters, + no_parameters_tool_choice_name, + no_parameter_description_or_required, + no_parameter_description, + string_enum, + boolean, + array, + no_type, + null, +] \ No newline at end of file diff --git a/litellm/tests/test_token_counter.py b/litellm/tests/test_token_counter.py index e61762131..59d908afe 100644 --- a/litellm/tests/test_token_counter.py +++ b/litellm/tests/test_token_counter.py @@ -3,15 +3,14 @@ import os import sys -import traceback +import time +from unittest.mock import MagicMock import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import time -from unittest.mock import AsyncMock, MagicMock, patch from litellm import ( create_pretrained_tokenizer, @@ -21,7 +20,7 @@ from litellm import ( token_counter, ) from litellm.tests.large_text import text - +from litellm.tests.messages_with_counts import MESSAGES_TEXT, MESSAGES_WITH_IMAGES, MESSAGES_WITH_TOOLS def test_token_counter_normal_plus_function_calling(): try: @@ -56,9 +55,48 @@ def test_token_counter_normal_plus_function_calling(): except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") - # test_token_counter_normal_plus_function_calling() +@pytest.mark.parametrize( + 
"message_count_pair", + MESSAGES_TEXT, +) +def test_token_counter_textonly(message_count_pair): + counted_tokens = token_counter( + model="gpt-35-turbo", + messages=[message_count_pair["message"]] + ) + assert counted_tokens == message_count_pair["count"] + +@pytest.mark.parametrize( + "message_count_pair", + MESSAGES_WITH_IMAGES, +) +def test_token_counter_with_images(message_count_pair): + counted_tokens = token_counter( + model="gpt-4o", + messages=[message_count_pair["message"]] + ) + assert counted_tokens == message_count_pair["count"] + + +@pytest.mark.parametrize( + "message_count_pair", + MESSAGES_WITH_TOOLS, +) +def test_token_counter_with_tools(message_count_pair): + counted_tokens = token_counter( + model="gpt-35-turbo", + messages=[message_count_pair["system_message"]], + tools=message_count_pair["tools"], + tool_choice=message_count_pair["tool_choice"], + ) + expected_tokens = message_count_pair["count"] + diff = counted_tokens - expected_tokens + assert ( + diff >= 0 and diff <= 3 + ), f"Expected {expected_tokens} tokens, got {counted_tokens}. Counted tokens is only allowed to be off by 3 in the over-counting direction." + def test_tokenizers(): try: diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 63f07f2ca..42f1dac3d 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -401,6 +401,18 @@ class ChatCompletionToolParam(TypedDict): function: ChatCompletionToolParamFunctionChunk +class Function(TypedDict, total=False): + name: Required[str] + """The name of the function to call.""" + + +class ChatCompletionNamedToolChoiceParam(TypedDict, total=False): + function: Required[Function] + + type: Required[Literal["function"]] + """The type of the tool. Currently, only `function` is supported.""" + + class ChatCompletionRequest(TypedDict, total=False): model: Required[str] messages: Required[List[AllMessageValues]] diff --git a/litellm/utils.py b/litellm/utils.py index c525af477..dc06759eb 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -79,6 +79,7 @@ from litellm.types.utils import ( TranscriptionResponse, Usage, ) +from litellm.types.llms.openai import ChatCompletionToolParam, ChatCompletionNamedToolChoiceParam oidc_cache = DualCache() @@ -1571,6 +1572,8 @@ def openai_token_counter( model="gpt-3.5-turbo-0613", text: Optional[str] = None, is_tool_call: Optional[bool] = False, + tools: list[ChatCompletionToolParam] | None = None, + tool_choice: ChatCompletionNamedToolChoiceParam | None = None, count_response_tokens: Optional[ bool ] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter @@ -1605,6 +1608,7 @@ def openai_token_counter( f"""num_tokens_from_messages() is not implemented for model {model}. 
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" ) num_tokens = 0 + includes_system_message = False if is_tool_call and text is not None: # if it's a tool call we assembled 'text' in token_counter() @@ -1612,6 +1616,8 @@ def openai_token_counter( elif messages is not None: for message in messages: num_tokens += tokens_per_message + if message.get("role", None) == "system": + includes_system_message = True for key, value in message.items(): if isinstance(value, str): num_tokens += len(encoding.encode(value, disallowed_special=())) @@ -1629,12 +1635,12 @@ def openai_token_counter( image_url_dict = c["image_url"] detail = image_url_dict.get("detail", "auto") url = image_url_dict.get("url") - num_tokens += calculage_img_tokens( + num_tokens += _calculate_img_tokens( data=url, mode=detail ) elif isinstance(c["image_url"], str): image_url_str = c["image_url"] - num_tokens += calculage_img_tokens( + num_tokens += _calculate_img_tokens( data=image_url_str, mode="auto" ) elif text is not None and count_response_tokens == True: @@ -1644,6 +1650,22 @@ def openai_token_counter( elif text is not None: num_tokens = len(encoding.encode(text, disallowed_special=())) num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + + if tools: + num_tokens += len(encoding.encode(_format_function_definitions(tools))) + num_tokens += 9 # Additional tokens for function definition of tools + # If there's a system message and tools are present, subtract four tokens + if tools and includes_system_message: + num_tokens -= 4 + # If tool_choice is 'none', add one token. + # If it's an object, add 7 + the number of tokens in the function name. + # If it's undefined or 'auto', don't add anything. 
+ if tool_choice == "none": + num_tokens += 1 + elif isinstance(tool_choice, dict): + num_tokens += 7 + num_tokens += len(encoding.encode(tool_choice["function"]["name"])) + return num_tokens @@ -1652,6 +1674,10 @@ def resize_image_high_res(width, height): max_short_side = 768 max_long_side = 2000 + # Return early if no resizing is needed + if width <= 768 and height <= 768: + return width, height + # Determine the longer and shorter sides longer_side = max(width, height) shorter_side = min(width, height) @@ -1723,7 +1749,7 @@ def get_image_dimensions(data): return None, None -def calculage_img_tokens( +def _calculate_img_tokens( data, mode: Literal["low", "high", "auto"] = "auto", base_tokens: int = 85, # openai default - https://openai.com/pricing @@ -1776,6 +1802,70 @@ def create_tokenizer(json: str): tokenizer = Tokenizer.from_str(json) return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} +# Based on https://github.com/forestwanglin/openai-java/blob/main/jtokkit/src/main/java/xyz/felh/openai/jtokkit/utils/TikTokenUtils.java + + +def _format_function_definitions(tools): + lines = [] + lines.append("namespace functions {") + lines.append("") + for tool in tools: + function = tool.get("function") + if function_description := function.get("description"): + lines.append(f"// {function_description}") + function_name = function.get("name") + parameters = function.get("parameters", {}) + properties = parameters.get("properties") + if properties and properties.keys(): + lines.append(f"type {function_name} = (_: {{") + lines.append(_format_object_parameters(parameters, 0)) + lines.append("}) => any;") + else: + lines.append(f"type {function_name} = () => any;") + lines.append("") + lines.append("} // namespace functions") + return "\n".join(lines) + + +def _format_object_parameters(parameters, indent): + properties = parameters.get("properties") + if not properties: + return "" + required_params = parameters.get("required", []) + lines = [] + for key, props in properties.items(): + description = props.get("description") + if description: + lines.append(f"// {description}") + question = "?" + if required_params and key in required_params: + question = "" + lines.append(f"{key}{question}: {_format_type(props, indent)},") + return "\n".join([" " * max(0, indent) + line for line in lines]) + + +def _format_type(props, indent): + type = props.get("type") + if type == "string": + if "enum" in props: + return " | ".join([f'"{item}"' for item in props["enum"]]) + return "string" + elif type == "array": + # items is required, OpenAI throws an error if it's missing + return f"{_format_type(props['items'], indent)}[]" + elif type == "object": + return f"{{\n{_format_object_parameters(props, indent + 2)}\n}}" + elif type in ["integer", "number"]: + if "enum" in props: + return " | ".join([f'"{item}"' for item in props["enum"]]) + return "number" + elif type == "boolean": + return "boolean" + elif type == "null": + return "null" + else: + # This is a guess, as an empty string doesn't yield the expected token count + return "any" def token_counter( model="", @@ -1783,6 +1873,8 @@ def token_counter( text: Optional[Union[str, List[str]]] = None, messages: Optional[List] = None, count_response_tokens: Optional[bool] = False, + tools: list[ChatCompletionToolParam] | None = None, + tool_choice: ChatCompletionNamedToolChoiceParam | None = None, ) -> int: """ Count the number of tokens in a given text using a specified model. 
@@ -1817,12 +1909,12 @@ def token_counter( image_url_dict = c["image_url"] detail = image_url_dict.get("detail", "auto") url = image_url_dict.get("url") - num_tokens += calculage_img_tokens( + num_tokens += _calculate_img_tokens( data=url, mode=detail ) elif isinstance(c["image_url"], str): image_url_str = c["image_url"] - num_tokens += calculage_img_tokens( + num_tokens += _calculate_img_tokens( data=image_url_str, mode="auto" ) if "tool_calls" in message: @@ -1861,6 +1953,8 @@ def token_counter( messages=messages, is_tool_call=is_tool_call, count_response_tokens=count_response_tokens, + tools=tools, + tool_choice=tool_choice ) else: print_verbose( @@ -1872,6 +1966,8 @@ def token_counter( messages=messages, is_tool_call=is_tool_call, count_response_tokens=count_response_tokens, + tools=tools, + tool_choice=tool_choice ) else: num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore @@ -1892,7 +1988,7 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool: def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool: """ - Check if the given model supports function calling and return a boolean value. + Check if the given model supports system messages and return a boolean value. Parameters: model (str): The model name to be checked.
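Reviewer note: the following is a minimal usage sketch, not part of the patch, showing how the new tools/tool_choice parameters of token_counter are exercised. It mirrors test_token_counter_with_tools and copies the search_sources tool and its expected count (75) from litellm/tests/messages_with_counts.py; the test tolerates up to 3 tokens of over-counting.

# Sketch only: mirrors test_token_counter_with_tools for one fixture.
# The tool definition and the expected count are copied from messages_with_counts.py.
from litellm import token_counter

search_sources_tool = {
    "type": "function",
    "function": {
        "name": "search_sources",
        "description": "Retrieve sources from the Azure AI Search index",
        "parameters": {
            "type": "object",
            "properties": {
                "search_query": {
                    "type": "string",
                    "description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
                }
            },
            "required": ["search_query"],
        },
    },
}

counted = token_counter(
    model="gpt-35-turbo",
    messages=[{"role": "system", "content": "You are a bot."}],
    tools=[search_sources_tool],
    tool_choice={"type": "function", "function": {"name": "search_sources"}},
)
# The search_sources_toolchoice_name fixture expects 75; the test allows up to
# 3 tokens of over-counting, so any value in 75..78 would pass.
assert 75 <= counted <= 78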
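For context on where such counts come from, here is a second sketch, also not part of the patch. It prints the TypeScript-style block that the new private helper _format_function_definitions builds for the parameter-less no_parameters fixture; the trailing comments note how the fixed offsets added to openai_token_counter (+9 for tool definitions, -4 when a system message is present, +1 for tool_choice="none", +7 plus the name tokens for a named tool_choice) combine with it. The expected output and the arithmetic are derived from the added helper code and the fixture counts, not re-measured against the API.

# Illustration only: inspect the serialized tool definitions that drive the token count.
# _format_function_definitions is a private helper added to litellm/utils.py by this patch.
from litellm.utils import _format_function_definitions

tools = [
    {
        "type": "function",
        "function": {
            "name": "search_sources",
            "description": "Retrieve sources from the Azure AI Search index",
        },
    }
]

print(_format_function_definitions(tools))
# Expected output, per the helper's logic:
# namespace functions {
#
# // Retrieve sources from the Azure AI Search index
# type search_sources = () => any;
#
# } // namespace functions
#
# openai_token_counter then adds len(encoding.encode(<that string>)) + 9, subtracts 4
# because a system message is present, and adds nothing for tool_choice="auto".
# With the "You are a bot." system message counting 12 tokens on its own
# (system_message_short), the no_parameters fixture's count of 42 implies the
# serialized block encodes to roughly 25 tokens: 12 + 25 + 9 - 4 = 42.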