Improved trimming logic and OpenAI token counter

Duc Pham 2023-11-10 01:26:13 +07:00
parent 70311502c8
commit 07e8cf1d9a
2 changed files with 122 additions and 24 deletions


@@ -1,6 +1,6 @@
 import sys, os
-import traceback
 from dotenv import load_dotenv
+import copy
 load_dotenv()
 import os
@@ -38,7 +38,7 @@ def test_multiple_messages_trimming():
         {"role": "user", "content": "This is a long message that will exceed the token limit."},
         {"role": "user", "content": "This is another long message that will also exceed the limit."}
     ]
-    trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20)
+    trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=20)
     # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
     assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20
 # test_multiple_messages_trimming()
@@ -48,7 +48,7 @@ def test_multiple_messages_no_trimming():
         {"role": "user", "content": "This is a long message that will exceed the token limit."},
         {"role": "user", "content": "This is another long message that will also exceed the limit."}
     ]
-    trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100)
+    trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=100)
     print("Trimmed messages")
     print(trimmed_messages)
     assert(messages==trimmed_messages)
@@ -56,14 +56,42 @@ def test_multiple_messages_no_trimming():
 # test_multiple_messages_no_trimming()

-def test_large_trimming():
+def test_large_trimming_multiple_messages():
     messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
-    trimmed_messages = trim_messages(messages, max_tokens=20, model="random")
+    trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613")
     print("trimmed messages")
     print(trimmed_messages)
-    assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20
 # test_large_trimming()
+
+def test_large_trimming_single_message():
+    messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
+    trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613")
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0
+
+def test_trimming_with_system_message_within_max_tokens():
+    # This message is 33 tokens long
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    trimmed_messages = trim_messages(messages, max_tokens=30, model="gpt-4-0613") # The system message should fit within the token limit
+    assert len(trimmed_messages) == 2
+    assert trimmed_messages[0]["content"] == "This is a short system message"
+
+def test_trimming_with_system_message_exceeding_max_tokens():
+    # This message is 33 tokens long. The system message is 13 tokens long.
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
+    assert len(trimmed_messages) == 1
+    assert '..' in trimmed_messages[0]["content"]
+
+def test_trimming_should_not_change_original_messages():
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    messages_copy = copy.deepcopy(messages)
+    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
+    assert(messages==messages_copy)
+
 def test_get_valid_models():
     old_environ = os.environ
     os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ


@@ -19,6 +19,7 @@ import uuid
 import aiohttp
 import logging
 import asyncio
+import copy
 from tokenizers import Tokenizer
 from dataclasses import (
     dataclass,
@@ -1101,6 +1102,50 @@ def decode(model: str, tokens: List[int]):
     dec = tokenizer_json["tokenizer"].decode(tokens)
     return dec

+def openai_token_counter(messages, model="gpt-3.5-turbo-0613"):
+    """
+    Return the number of tokens used by a list of messages.
+    Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb.
+    """
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        return openai_token_counter(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return openai_token_counter(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
+
 def token_counter(model="", text=None, messages: Optional[List] = None):
     """
     Count the number of tokens in a given text using a specified model.
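To see the arithmetic concretely, here is a small self-contained sketch of the same counting scheme, using tiktoken directly (token counts assume the cl100k_base encoding and may vary slightly across tiktoken versions):

    import tiktoken

    # Same scheme as openai_token_counter above: 3 overhead tokens per
    # message for the -0613 models, plus 3 tokens priming the reply.
    encoding = tiktoken.get_encoding("cl100k_base")
    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi!"},
    ]
    num_tokens = 3  # reply priming: <|start|>assistant<|message|>
    for message in messages:
        num_tokens += 3  # tokens_per_message
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    print(num_tokens)  # content tokens plus 9 tokens of chat-format overhead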
@@ -1127,6 +1172,9 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
         enc = tokenizer_json["tokenizer"].encode(text)
         num_tokens = len(enc.ids)
     elif tokenizer_json["type"] == "openai_tokenizer":
+        if messages is not None:
+            num_tokens = openai_token_counter(messages, model=model)
+        else:
             enc = tokenizer_json["tokenizer"].encode(text)
             num_tokens = len(enc)
     else:
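With this branch in place, token_counter accepts either raw text or a chat message list; a minimal usage sketch (the import path is an assumption based on where the function lives in this diff):

    from litellm import token_counter  # import path assumed

    # Counting raw text uses the model's tokenizer directly.
    n_text = token_counter(model="gpt-3.5-turbo", text="Hello, world!")

    # Counting a message list now routes through openai_token_counter,
    # so the per-message chat-format overhead is included.
    n_msgs = token_counter(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    assert n_msgs > n_text  # same content, plus overhead tokens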
@@ -4429,7 +4477,7 @@ def completion_with_config(config: Union[dict, str], **kwargs):
         except:
             continue
     if prompt_larger_than_model:
-        messages = trim_messages(messages=messages, model=max_model)
+        messages = trim_messages(messages_copy=messages, model=max_model)
         kwargs["messages"] = messages

     kwargs["model"] = model
@@ -4528,13 +4576,13 @@ def completion_with_fallbacks(**kwargs):

 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
-    system_message_tokens = get_token_count(system_message_event, model)
+    system_message_tokens = get_token_count([system_message_event], model)

     if system_message_tokens > max_tokens:
         print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
         # shorten system message to fit within max_tokens
         new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
-        system_message_tokens = get_token_count(new_system_message, model)
+        system_message_tokens = get_token_count([new_system_message], model)

     return system_message_event, max_tokens - system_message_tokens
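The bracket change above is the actual fix: the token counters iterate over a list of message dicts, so passing a bare dict makes the loop walk its keys instead. A tiny illustration in plain Python:

    event = {"role": "system", "content": "hi"}

    for message in [event]:         # what the counters expect
        print(message["content"])   # -> hi

    for message in event:           # iterating the bare dict yields its keys
        print(message)              # -> role, content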
@@ -4544,11 +4592,15 @@ def process_messages(messages, max_tokens, model):
     final_messages = []
     for message in messages:
-        final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
+        used_tokens = get_token_count(final_messages, model)
+        available_tokens = max_tokens - used_tokens
+        if available_tokens <= 3:
+            break
+        final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model)

     return final_messages

-def attempt_message_addition(final_messages, message, max_tokens, model):
+def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model):
     temp_messages = [message] + final_messages
     temp_message_tokens = get_token_count(messages=temp_messages, model=model)
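A standalone sketch of the budget check this hunk adds, with a toy counter standing in for get_token_count (the <= 3 floor mirrors the reply-priming overhead, so near-empty budgets stop the loop instead of producing useless fragments):

    def toy_token_count(messages):
        # Stand-in for get_token_count: one token per word.
        return sum(len(m["content"].split()) for m in messages)

    def process_messages_sketch(messages, max_tokens):
        final_messages = []
        for message in messages:
            available_tokens = max_tokens - toy_token_count(final_messages)
            if available_tokens <= 3:
                break  # no meaningful room left; stop rather than over-trim
            # the real code calls attempt_message_addition() here
            final_messages = [message] + final_messages
        return final_messages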
@@ -4558,7 +4610,7 @@ def attempt_message_addition(final_messages, message, max_tokens, model):
     # if temp_message_tokens > max_tokens, try shortening temp_messages
     elif "function_call" not in message:
         # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greater than max_tokens)
-        updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
+        updated_message = shorten_message_to_fit_limit(message, available_tokens, model)
         if can_add_message(updated_message, final_messages, max_tokens, model):
             return [updated_message] + final_messages
@ -4580,6 +4632,13 @@ def shorten_message_to_fit_limit(
""" """
Shorten a message to fit within a token limit by removing characters from the middle. Shorten a message to fit within a token limit by removing characters from the middle.
""" """
# For OpenAI models, even blank messages cost 7 token,
# and if the buffer is less than 3, the while loop will never end,
# hence the value 10.
if 'gpt' in model and tokens_needed <= 10:
return message
content = message["content"] content = message["content"]
while True: while True:
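The value 7 in the new comment can be verified directly: in the OpenAI chat format even an empty message carries fixed overhead. A quick check, assuming tiktoken and the cl100k_base encoding:

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")
    blank = {"role": "user", "content": ""}
    overhead = 3 + 3  # tokens_per_message + reply priming
    overhead += sum(len(encoding.encode(v)) for v in blank.values())
    print(overhead)  # 7: a blank message is never free, so tiny budgets can't converge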
@@ -4607,7 +4666,7 @@ def shorten_message_to_fit_limit(
 # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py
 # Credits for this code go to Killian Lucas
 def trim_messages(
-    messages,
+    messages_copy,
     model: Optional[str] = None,
     trim_ratio: float = 0.75,
     return_response_tokens: bool = False,
@@ -4628,6 +4687,7 @@ def trim_messages(
     """
     # Initialize max_tokens
    # if users pass in max tokens, trim to this amount
+    messages_copy = copy.deepcopy(messages_copy)
     try:
         print_verbose(f"trimming messages")
         if max_tokens == None:
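The deepcopy is what makes trimming non-destructive for the caller; a minimal illustration of the aliasing bug it prevents:

    import copy

    messages = [{"role": "user", "content": "hello"}]
    alias = messages                # same list object: edits leak back to the caller
    safe = copy.deepcopy(messages)  # independent copy: edits stay local

    safe[0]["content"] = "trimmed.."
    assert messages[0]["content"] == "hello"  # original left intact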
@@ -4642,33 +4702,43 @@ def trim_messages(
                 return

         system_message = ""
-        for message in messages:
+        for message in messages_copy:
             if message["role"] == "system":
+                system_message += '\n' if system_message else ''
                 system_message += message["content"]

-        current_tokens = token_counter(model=model, messages=messages)
+        current_tokens = token_counter(model=model, messages=messages_copy)
         print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")

         # Do nothing if current tokens are under max_tokens
         if current_tokens < max_tokens:
-            return messages
+            return messages_copy

         #### Trimming messages if current_tokens > max_tokens
-        print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
+        print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}")

         if system_message:
             system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
-            messages = messages + [system_message_event]

-        final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
+            if max_tokens == 0:  # the system messages are too long
+                return [system_message_event]
+
+            # Since all system messages are combined and trimmed to fit the max_tokens,
+            # we remove all system messages from the messages list
+            messages_copy = [message for message in messages_copy if message["role"] != "system"]
+
+        final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model)
+
+        # Add system message to the beginning of the final messages
+        if system_message:
+            final_messages = [system_message_event] + final_messages

         if return_response_tokens:  # if user wants token count with new trimmed messages
             response_tokens = max_tokens - get_token_count(final_messages, model)
             return final_messages, response_tokens
         return final_messages
     except Exception as e:  # [NON-Blocking] if error occurs just return final_messages
         print_verbose(f"Got exception while token trimming {e}")
-        return messages
+        return messages_copy

 def get_valid_models():
     """