Mirror of https://github.com/BerriAI/litellm.git
Merge pull request #4716 from pamelafox/countfuncs

Add token counting for OpenAI tools/tool_choice

Commit: 0fb88e527c
4 changed files with 860 additions and 6 deletions
```diff
@@ -79,6 +79,7 @@ from litellm.types.utils import (
     TranscriptionResponse,
     Usage,
 )
+from litellm.types.llms.openai import ChatCompletionToolParam, ChatCompletionNamedToolChoiceParam
 
 oidc_cache = DualCache()
 
```
```diff
@@ -1571,6 +1572,8 @@ def openai_token_counter(
     model="gpt-3.5-turbo-0613",
     text: Optional[str] = None,
     is_tool_call: Optional[bool] = False,
+    tools: list[ChatCompletionToolParam] | None = None,
+    tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
     count_response_tokens: Optional[
         bool
     ] = False,  # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
```
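The two new parameters take the same shapes that OpenAI's Chat Completions API uses for tools. As a point of reference, here is a minimal sketch of the dicts they expect; the `get_current_weather` tool and its fields are illustrative placeholders, not something from this PR:

```python
# Illustrative only: dicts matching the ChatCompletionToolParam and
# ChatCompletionNamedToolChoiceParam shapes; the function name and
# parameters below are made up for the example.
example_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name"},
            },
            "required": ["location"],
        },
    },
}

example_tool_choice = {
    "type": "function",
    "function": {"name": "get_current_weather"},
}
```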
```diff
@@ -1605,6 +1608,7 @@ def openai_token_counter(
             f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
         )
     num_tokens = 0
+    includes_system_message = False
 
     if is_tool_call and text is not None:
         # if it's a tool call we assembled 'text' in token_counter()
```
```diff
@@ -1612,6 +1616,8 @@ def openai_token_counter(
     elif messages is not None:
         for message in messages:
             num_tokens += tokens_per_message
+            if message.get("role", None) == "system":
+                includes_system_message = True
             for key, value in message.items():
                 if isinstance(value, str):
                     num_tokens += len(encoding.encode(value, disallowed_special=()))
```
```diff
@@ -1644,6 +1650,22 @@ def openai_token_counter(
     elif text is not None:
         num_tokens = len(encoding.encode(text, disallowed_special=()))
         num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+
+    if tools:
+        num_tokens += len(encoding.encode(_format_function_definitions(tools)))
+        num_tokens += 9  # Additional tokens for function definition of tools
+    # If there's a system message and tools are present, subtract four tokens
+    if tools and includes_system_message:
+        num_tokens -= 4
+    # If tool_choice is 'none', add one token.
+    # If it's an object, add 4 + the number of tokens in the function name.
+    # If it's undefined or 'auto', don't add anything.
+    if tool_choice == "none":
+        num_tokens += 1
+    elif isinstance(tool_choice, dict):
+        num_tokens += 7
+        num_tokens += len(encoding.encode(tool_choice["function"]["name"]))
+
     return num_tokens
 
 
```
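The constants in this hunk encode the overhead OpenAI appears to apply when tools are in play: the serialized tool definitions plus 9 tokens of fixed preamble, minus 4 when a system message is also present, plus 1 for `tool_choice="none"`, and plus 7 plus the encoded function name for a named choice. A self-contained sketch of the same accounting, assuming `tiktoken` with the `cl100k_base` encoding (the constants come straight from the diff; the helper name is mine):

```python
import tiktoken

def tool_token_overhead(formatted_defs: str, has_system_message: bool, tool_choice) -> int:
    """Sketch of the tool-related overhead added above, outside of litellm.
    Assumes tools are present (callers skip this entirely when there are none)."""
    enc = tiktoken.get_encoding("cl100k_base")
    n = len(enc.encode(formatted_defs)) + 9      # serialized definitions + fixed preamble
    if has_system_message:
        n -= 4                                   # overlap with the system message
    if tool_choice == "none":
        n += 1
    elif isinstance(tool_choice, dict):
        n += 7 + len(enc.encode(tool_choice["function"]["name"]))
    return n
```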
```diff
@@ -1652,6 +1674,10 @@ def resize_image_high_res(width, height):
     max_short_side = 768
     max_long_side = 2000
 
+    # Return early if no resizing is needed
+    if width <= 768 and height <= 768:
+        return width, height
+
     # Determine the longer and shorter sides
     longer_side = max(width, height)
     shorter_side = min(width, height)
```
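The early return added here skips any resizing work for images already within the 768×768 low-res budget. Reading the behaviour directly off the diff:

```python
# Images at or under 768x768 now pass through unchanged.
assert resize_image_high_res(640, 480) == (640, 480)
```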
```diff
@@ -1777,12 +1803,79 @@ def create_tokenizer(json: str):
     return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
 
 
+def _format_function_definitions(tools):
+    """Formats tool definitions in the format that OpenAI appears to use.
+    Based on https://github.com/forestwanglin/openai-java/blob/main/jtokkit/src/main/java/xyz/felh/openai/jtokkit/utils/TikTokenUtils.java
+    """
+    lines = []
+    lines.append("namespace functions {")
+    lines.append("")
+    for tool in tools:
+        function = tool.get("function")
+        if function_description := function.get("description"):
+            lines.append(f"// {function_description}")
+        function_name = function.get("name")
+        parameters = function.get("parameters", {})
+        properties = parameters.get("properties")
+        if properties and properties.keys():
+            lines.append(f"type {function_name} = (_: {{")
+            lines.append(_format_object_parameters(parameters, 0))
+            lines.append("}) => any;")
+        else:
+            lines.append(f"type {function_name} = () => any;")
+        lines.append("")
+    lines.append("} // namespace functions")
+    return "\n".join(lines)
+
+
+def _format_object_parameters(parameters, indent):
+    properties = parameters.get("properties")
+    if not properties:
+        return ""
+    required_params = parameters.get("required", [])
+    lines = []
+    for key, props in properties.items():
+        description = props.get("description")
+        if description:
+            lines.append(f"// {description}")
+        question = "?"
+        if required_params and key in required_params:
+            question = ""
+        lines.append(f"{key}{question}: {_format_type(props, indent)},")
+    return "\n".join([" " * max(0, indent) + line for line in lines])
+
+
+def _format_type(props, indent):
+    type = props.get("type")
+    if type == "string":
+        if "enum" in props:
+            return " | ".join([f'"{item}"' for item in props["enum"]])
+        return "string"
+    elif type == "array":
+        # items is required, OpenAI throws an error if it's missing
+        return f"{_format_type(props['items'], indent)}[]"
+    elif type == "object":
+        return f"{{\n{_format_object_parameters(props, indent + 2)}\n}}"
+    elif type in ["integer", "number"]:
+        if "enum" in props:
+            return " | ".join([f'"{item}"' for item in props["enum"]])
+        return "number"
+    elif type == "boolean":
+        return "boolean"
+    elif type == "null":
+        return "null"
+    else:
+        # This is a guess, as an empty string doesn't yield the expected token count
+        return "any"
+
+
 def token_counter(
     model="",
     custom_tokenizer: Optional[dict] = None,
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
+    tools: list[ChatCompletionToolParam] | None = None,
+    tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
 ) -> int:
     """
     Count the number of tokens in a given text using a specified model.
 
```
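`_format_function_definitions` renders tool schemas into the TypeScript-style namespace text that OpenAI appears to tokenize internally. Tracing the code above with the illustrative `get_current_weather` tool from earlier (so this output is inferred, not taken from the PR), it would produce roughly:

```
namespace functions {

// Get the current weather in a given location
type get_current_weather = (_: {
// City name
location: string,
}) => any;

} // namespace functions
```

That string is what the `encoding.encode(_format_function_definitions(tools))` call in the earlier hunk actually counts.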
```diff
@@ -1861,6 +1954,8 @@ def token_counter(
                 messages=messages,
                 is_tool_call=is_tool_call,
                 count_response_tokens=count_response_tokens,
+                tools=tools,
+                tool_choice=tool_choice
             )
         else:
             print_verbose(
```
```diff
@@ -1872,6 +1967,8 @@ def token_counter(
                 messages=messages,
                 is_tool_call=is_tool_call,
                 count_response_tokens=count_response_tokens,
+                tools=tools,
+                tool_choice=tool_choice
             )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore
```
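With both call sites above forwarding the new arguments, counting prompt tokens for a tool-enabled request becomes a single call to the public API. A usage sketch, reusing the illustrative tool and named choice from earlier:

```python
import litellm

num_tokens = litellm.token_counter(
    model="gpt-3.5-turbo-0613",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[example_tool],
    tool_choice=example_tool_choice,
)
print(num_tokens)  # message tokens plus the tool-definition overhead
```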
```diff
@@ -1892,7 +1989,7 @@ def supports_httpx_timeout(custom_llm_provider: str) -> bool:
 
 def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
     """
-    Check if the given model supports function calling and return a boolean value.
+    Check if the given model supports system messages and return a boolean value.
 
     Parameters:
         model (str): The model name to be checked.
```