Count tokens for tools

This commit is contained in:
Pamela Fox 2024-07-15 11:07:52 -07:00
parent 3dc2ec8119
commit d43dbc756b
4 changed files with 863 additions and 11 deletions

View file

@ -79,6 +79,7 @@ from litellm.types.utils import (
TranscriptionResponse,
Usage,
)
from litellm.types.llms.openai import ChatCompletionToolParam, ChatCompletionNamedToolChoiceParam
oidc_cache = DualCache()
@ -1571,6 +1572,8 @@ def openai_token_counter(
model="gpt-3.5-turbo-0613",
text: Optional[str] = None,
is_tool_call: Optional[bool] = False,
tools: list[ChatCompletionToolParam] | None = None,
tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
count_response_tokens: Optional[
bool
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
@ -1605,6 +1608,7 @@ def openai_token_counter(
f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
num_tokens = 0
includes_system_message = False
if is_tool_call and text is not None:
# if it's a tool call we assembled 'text' in token_counter()
@ -1612,6 +1616,8 @@ def openai_token_counter(
elif messages is not None:
for message in messages:
num_tokens += tokens_per_message
if message.get("role", None) == "system":
includes_system_message = True
for key, value in message.items():
if isinstance(value, str):
num_tokens += len(encoding.encode(value, disallowed_special=()))
@ -1629,12 +1635,12 @@ def openai_token_counter(
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculage_img_tokens(
num_tokens += _calculate_img_tokens(
data=url, mode=detail
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculage_img_tokens(
num_tokens += _calculate_img_tokens(
data=image_url_str, mode="auto"
)
elif text is not None and count_response_tokens == True:
@ -1644,6 +1650,22 @@ def openai_token_counter(
elif text is not None:
num_tokens = len(encoding.encode(text, disallowed_special=()))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
if tools:
num_tokens += len(encoding.encode(_format_function_definitions(tools)))
num_tokens += 9 # Additional tokens for function definition of tools
# If there's a system message and tools are present, subtract four tokens
if tools and includes_system_message:
num_tokens -= 4
# If tool_choice is 'none', add one token.
# If it's an object, add 7 + the number of tokens in the function name.
# If it's undefined or 'auto', don't add anything.
if tool_choice == "none":
num_tokens += 1
elif isinstance(tool_choice, dict):
num_tokens += 7
num_tokens += len(encoding.encode(tool_choice["function"]["name"]))
return num_tokens
@ -1652,6 +1674,10 @@ def resize_image_high_res(width, height):
max_short_side = 768
max_long_side = 2000
# Return early if no resizing is needed
if width <= 768 and height <= 768:
return width, height
# Determine the longer and shorter sides
longer_side = max(width, height)
shorter_side = min(width, height)
@ -1723,7 +1749,7 @@ def get_image_dimensions(data):
return None, None
def calculage_img_tokens(
def _calculate_img_tokens(
data,
mode: Literal["low", "high", "auto"] = "auto",
base_tokens: int = 85, # openai default - https://openai.com/pricing
@ -1776,6 +1802,70 @@ def create_tokenizer(json: str):
tokenizer = Tokenizer.from_str(json)
return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
# Based on https://github.com/forestwanglin/openai-java/blob/main/jtokkit/src/main/java/xyz/felh/openai/jtokkit/utils/TikTokenUtils.java
def _format_function_definitions(tools):
    """Render tool definitions as a ``namespace functions { ... }`` text block.

    The result is a TypeScript-style declaration with one
    ``type <name> = ...`` entry per tool, optionally preceded by a
    ``// <description>`` comment line. The caller encodes this string to
    count the tokens contributed by the tool definitions.

    Args:
        tools: Iterable of tool dicts. Each is expected to carry a
            ``"function"`` dict with a ``name`` and, optionally, a
            ``description`` and a ``parameters`` JSON-schema object.

    Returns:
        The namespace declaration as one newline-joined string.
    """
    lines = ["namespace functions {", ""]
    for tool in tools:
        function = tool.get("function")
        if not function:
            # An entry without a function definition has nothing to render;
            # skip it instead of failing on ``None.get``.
            continue
        if function_description := function.get("description"):
            lines.append(f"// {function_description}")
        function_name = function.get("name")
        # ``parameters`` may be missing *or* explicitly null; both mean
        # "takes no arguments".
        parameters = function.get("parameters") or {}
        if parameters.get("properties"):
            lines.append(f"type {function_name} = (_: {{")
            lines.append(_format_object_parameters(parameters, 0))
            lines.append("}) => any;")
        else:
            lines.append(f"type {function_name} = () => any;")
        # Blank separator line after every function entry.
        lines.append("")
    lines.append("} // namespace functions")
    return "\n".join(lines)
def _format_object_parameters(parameters, indent):
    """Render the properties of an object schema, one ``name: type,`` per line.

    Args:
        parameters: Object schema dict; its ``properties`` mapping and
            optional ``required`` list are read.
        indent: Number of spaces prefixed to each rendered entry (negative
            values are treated as 0).

    Returns:
        The newline-joined entries, or ``""`` when there are no properties.
        A property that is not listed in ``required`` gets a ``?`` suffix on
        its name; a property with a description is preceded by a
        ``// <description>`` line.
    """
    schema_by_name = parameters.get("properties")
    if not schema_by_name:
        return ""

    required = parameters.get("required", [])
    entries = []
    for name, schema in schema_by_name.items():
        if comment := schema.get("description"):
            entries.append(f"// {comment}")
        optional_marker = "" if (required and name in required) else "?"
        entries.append(f"{name}{optional_marker}: {_format_type(schema, indent)},")

    # The prefix is applied per entry; an entry that itself spans several
    # lines (a nested object) only has its first line prefixed here.
    prefix = " " * max(0, indent)
    return "\n".join(prefix + entry for entry in entries)
def _format_type(props, indent):
    """Render a JSON-schema property as a TypeScript-style type expression.

    Args:
        props: Schema dict for a single property; its ``type`` and, depending
            on the type, ``enum`` / ``items`` / ``properties`` are read.
        indent: Current indentation; nested objects are rendered at
            ``indent + 2``.

    Returns:
        ``"string"``, ``"number"``, ``"boolean"``, ``"null"``, a
        ``" | "``-joined union of quoted enum members, ``"<item type>[]"`` for
        arrays, a braced block for objects, or ``"any"`` for anything else.
    """
    schema_type = props.get("type")

    if schema_type in ("string", "integer", "number"):
        if "enum" in props:
            # NOTE(review): numeric enum members are quoted exactly like
            # string members - confirm this matches the expected token count.
            return " | ".join(f'"{item}"' for item in props["enum"])
        return "string" if schema_type == "string" else "number"

    if schema_type == "array":
        # OpenAI rejects array schemas without ``items``, but token counting
        # should not crash on one: fall back to an untyped element ("any[]").
        item_schema = props.get("items") or {}
        return f"{_format_type(item_schema, indent)}[]"

    if schema_type == "object":
        return f"{{\n{_format_object_parameters(props, indent + 2)}\n}}"

    if schema_type == "boolean":
        return "boolean"

    if schema_type == "null":
        return "null"

    # This is a guess, as an empty string doesn't yield the expected token count
    return "any"
def token_counter(
model="",
@ -1783,6 +1873,8 @@ def token_counter(
text: Optional[Union[str, List[str]]] = None,
messages: Optional[List] = None,
count_response_tokens: Optional[bool] = False,
tools: list[ChatCompletionToolParam] | None = None,
tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
) -> int:
"""
Count the number of tokens in a given text using a specified model.
@ -1817,12 +1909,12 @@ def token_counter(
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculage_img_tokens(
num_tokens += _calculate_img_tokens(
data=url, mode=detail
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculage_img_tokens(
num_tokens += _calculate_img_tokens(
data=image_url_str, mode="auto"
)
if "tool_calls" in message:
@ -1861,6 +1953,8 @@ def token_counter(
messages=messages,
is_tool_call=is_tool_call,
count_response_tokens=count_response_tokens,
tools=tools,
tool_choice=tool_choice
)
else:
print_verbose(
@ -1872,6 +1966,8 @@ def token_counter(
messages=messages,
is_tool_call=is_tool_call,
count_response_tokens=count_response_tokens,
tools=tools,
tool_choice=tool_choice
)
else:
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
@ -1892,7 +1988,7 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
Check if the given model supports system messages and return a boolean value.
Parameters:
model (str): The model name to be checked.