Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)

Improved trimming logic and OpenAI token counter
commit 07e8cf1d9a, parent 70311502c8

2 changed files with 122 additions and 24 deletions
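At a glance, this commit renames the first parameter of trim_messages() to messages_copy and deep-copies it on entry so callers' lists are never mutated, handles system messages as a separate budget, and adds an OpenAI-specific message token counter. A minimal usage sketch of the changed entry points, assuming trim_messages and get_token_count are importable from litellm.utils (the test file's imports are outside the hunks shown):

    from litellm.utils import trim_messages, get_token_count

    messages = [
        {"role": "system", "content": "This is a short system message"},
        {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."},
    ]

    # The first argument is now named `messages_copy` and is deep-copied internally,
    # so the caller's `messages` list is left unchanged after trimming.
    trimmed = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=20)
    assert get_token_count(messages=trimmed, model="gpt-3.5-turbo") <= 20
    assert messages[0]["content"] == "This is a short system message"  # original untouched

The first changed file (the token-trimming tests) follows; litellm/utils.py comes after.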
@@ -1,6 +1,6 @@
 import sys, os
-import traceback
 from dotenv import load_dotenv
+import copy
 
 load_dotenv()
 import os
@@ -38,7 +38,7 @@ def test_multiple_messages_trimming():
         {"role": "user", "content": "This is a long message that will exceed the token limit."},
         {"role": "user", "content": "This is another long message that will also exceed the limit."}
     ]
-    trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20)
+    trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=20)
     # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
     assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20
 # test_multiple_messages_trimming()
@@ -48,7 +48,7 @@ def test_multiple_messages_no_trimming():
         {"role": "user", "content": "This is a long message that will exceed the token limit."},
         {"role": "user", "content": "This is another long message that will also exceed the limit."}
     ]
-    trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100)
+    trimmed_messages = trim_messages(messages_copy=messages, model="gpt-3.5-turbo", max_tokens=100)
     print("Trimmed messages")
     print(trimmed_messages)
     assert(messages==trimmed_messages)
@@ -56,14 +56,42 @@ def test_multiple_messages_no_trimming():
 # test_multiple_messages_no_trimming()
 
 
-def test_large_trimming():
+def test_large_trimming_multiple_messages():
     messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
-    trimmed_messages = trim_messages(messages, max_tokens=20, model="random")
+    trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613")
     print("trimmed messages")
     print(trimmed_messages)
-    assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20
 # test_large_trimming()
 
+def test_large_trimming_single_message():
+    messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
+    trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613")
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5
+    assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0
+
+
+def test_trimming_with_system_message_within_max_tokens():
+    # This message is 33 tokens long
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    trimmed_messages = trim_messages(messages, max_tokens=30, model="gpt-4-0613") # The system message should fit within the token limit
+    assert len(trimmed_messages) == 2
+    assert trimmed_messages[0]["content"] == "This is a short system message"
+
+
+def test_trimming_with_system_message_exceeding_max_tokens():
+    # This message is 33 tokens long. The system message is 13 tokens long.
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
+    assert len(trimmed_messages) == 1
+    assert '..' in trimmed_messages[0]["content"]
+
+def test_trimming_should_not_change_original_messages():
+    messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
+    messages_copy = copy.deepcopy(messages)
+    trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
+    assert(messages==messages_copy)
+
 def test_get_valid_models():
     old_environ = os.environ
     os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ

litellm/utils.py (100 changed lines)
@@ -19,6 +19,7 @@ import uuid
 import aiohttp
 import logging
 import asyncio
+import copy
 from tokenizers import Tokenizer
 from dataclasses import (
     dataclass,
@@ -1101,6 +1102,50 @@ def decode(model: str, tokens: List[int]):
     dec = tokenizer_json["tokenizer"].decode(tokens)
     return dec
 
+def openai_token_counter(messages, model="gpt-3.5-turbo-0613"):
+    """
+    Return the number of tokens used by a list of messages.
+
+    Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb.
+    """
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        return openai_token_counter(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return openai_token_counter(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
+
 def token_counter(model="", text=None, messages: Optional[List] = None):
     """
     Count the number of tokens in a given text using a specified model.
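The new openai_token_counter() follows the OpenAI cookbook accounting: a fixed per-message overhead, plus the encoded length of every message field (role, content, name), plus 3 tokens priming the assistant reply. A small standalone illustration of the same arithmetic, assuming tiktoken is installed (it mirrors the logic above rather than calling litellm itself):

    import tiktoken

    # Same scheme the new counter uses for gpt-3.5-turbo-0613 / gpt-4-0613:
    # 3 tokens per message + encoded field values + 3 tokens priming the reply.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")
    messages = [{"role": "user", "content": "Hello there"}]

    num_tokens = 0
    for message in messages:
        num_tokens += 3  # tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    print(num_tokens)  # fixed overhead (6) + encoded "user" and "Hello there"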
@@ -1127,6 +1172,9 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
             enc = tokenizer_json["tokenizer"].encode(text)
             num_tokens = len(enc.ids)
         elif tokenizer_json["type"] == "openai_tokenizer":
+            if messages is not None:
+                num_tokens = openai_token_counter(messages, model=model)
+            else:
                 enc = tokenizer_json["tokenizer"].encode(text)
                 num_tokens = len(enc)
         else:
@@ -4429,7 +4477,7 @@ def completion_with_config(config: Union[dict, str], **kwargs):
             except:
                 continue
         if prompt_larger_than_model:
-            messages = trim_messages(messages=messages, model=max_model)
+            messages = trim_messages(messages_copy=messages, model=max_model)
             kwargs["messages"] = messages
 
     kwargs["model"] = model
@@ -4528,13 +4576,13 @@ def completion_with_fallbacks(**kwargs):
 
 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
-    system_message_tokens = get_token_count(system_message_event, model)
+    system_message_tokens = get_token_count([system_message_event], model)
 
     if system_message_tokens > max_tokens:
         print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
         # shorten system message to fit within max_tokens
         new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
-        system_message_tokens = get_token_count(new_system_message, model)
+        system_message_tokens = get_token_count([new_system_message], model)
 
     return system_message_event, max_tokens - system_message_tokens
 
@@ -4544,11 +4592,15 @@ def process_messages(messages, max_tokens, model):
     final_messages = []
 
     for message in messages:
-        final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
+        used_tokens = get_token_count(final_messages, model)
+        available_tokens = max_tokens - used_tokens
+        if available_tokens <= 3:
+            break
+        final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model)
 
     return final_messages
 
-def attempt_message_addition(final_messages, message, max_tokens, model):
+def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model):
     temp_messages = [message] + final_messages
     temp_message_tokens = get_token_count(messages=temp_messages, model=model)
 
@@ -4558,7 +4610,7 @@ def attempt_message_addition(final_messages, message, max_tokens, model):
     # if temp_message_tokens > max_tokens, try shortening temp_messages
     elif "function_call" not in message:
         # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greate than max_tokens)
-        updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
+        updated_message = shorten_message_to_fit_limit(message, available_tokens, model)
         if can_add_message(updated_message, final_messages, max_tokens, model):
             return [updated_message] + final_messages
 
@@ -4580,6 +4632,13 @@ def shorten_message_to_fit_limit(
     """
     Shorten a message to fit within a token limit by removing characters from the middle.
     """
+
+    # For OpenAI models, even blank messages cost 7 token,
+    # and if the buffer is less than 3, the while loop will never end,
+    # hence the value 10.
+    if 'gpt' in model and tokens_needed <= 10:
+        return message
+
     content = message["content"]
 
     while True:
@@ -4607,7 +4666,7 @@ def shorten_message_to_fit_limit(
 # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py
 # Credits for this code go to Killian Lucas
 def trim_messages(
-    messages,
+    messages_copy,
     model: Optional[str] = None,
     trim_ratio: float = 0.75,
     return_response_tokens: bool = False,
@@ -4628,6 +4687,7 @@ def trim_messages(
     """
     # Initialize max_tokens
    # if users pass in max tokens, trim to this amount
+    messages_copy = copy.deepcopy(messages_copy)
     try:
         print_verbose(f"trimming messages")
         if max_tokens == None:
|
@ -4642,33 +4702,43 @@ def trim_messages(
|
||||||
return
|
return
|
||||||
|
|
||||||
system_message = ""
|
system_message = ""
|
||||||
for message in messages:
|
for message in messages_copy:
|
||||||
if message["role"] == "system":
|
if message["role"] == "system":
|
||||||
|
system_message += '\n' if system_message else ''
|
||||||
system_message += message["content"]
|
system_message += message["content"]
|
||||||
|
|
||||||
current_tokens = token_counter(model=model, messages=messages)
|
current_tokens = token_counter(model=model, messages=messages_copy)
|
||||||
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
|
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
|
||||||
|
|
||||||
# Do nothing if current tokens under messages
|
# Do nothing if current tokens under messages
|
||||||
if current_tokens < max_tokens:
|
if current_tokens < max_tokens:
|
||||||
return messages
|
return messages_copy
|
||||||
|
|
||||||
#### Trimming messages if current_tokens > max_tokens
|
#### Trimming messages if current_tokens > max_tokens
|
||||||
print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
|
print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
|
||||||
if system_message:
|
if system_message:
|
||||||
system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
|
system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
|
||||||
messages = messages + [system_message_event]
|
|
||||||
|
|
||||||
final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
|
if max_tokens == 0: # the system messages are too long
|
||||||
|
return [system_message_event]
|
||||||
|
|
||||||
|
# Since all system messages are combined and trimmed to fit the max_tokens,
|
||||||
|
# we remove all system messages from the messages list
|
||||||
|
messages_copy = [message for message in messages_copy if message["role"] != "system"]
|
||||||
|
|
||||||
|
final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model)
|
||||||
|
|
||||||
|
# Add system message to the beginning of the final messages
|
||||||
|
if system_message:
|
||||||
|
final_messages = [system_message_event] + final_messages
|
||||||
|
|
||||||
if return_response_tokens: # if user wants token count with new trimmed messages
|
if return_response_tokens: # if user wants token count with new trimmed messages
|
||||||
response_tokens = max_tokens - get_token_count(final_messages, model)
|
response_tokens = max_tokens - get_token_count(final_messages, model)
|
||||||
return final_messages, response_tokens
|
return final_messages, response_tokens
|
||||||
|
|
||||||
return final_messages
|
return final_messages
|
||||||
except Exception as e: # [NON-Blocking, if error occurs just return final_messages
|
except Exception as e: # [NON-Blocking, if error occurs just return final_messages
|
||||||
print_verbose(f"Got exception while token trimming{e}")
|
print_verbose(f"Got exception while token trimming{e}")
|
||||||
return messages
|
return messages_copy
|
||||||
|
|
||||||
def get_valid_models():
|
def get_valid_models():
|
||||||
"""
|
"""
|
||||||
|
|
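Taken together, the trim_messages() changes treat system messages as their own budget: they are concatenated, trimmed separately, removed from the working list, and re-attached at the front of the result. A sketch of the observable behavior, mirroring the new tests above (assumes trim_messages is importable from litellm.utils):

    from litellm.utils import trim_messages

    messages = [
        {"role": "system", "content": "This is a short system message"},
        {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."},
    ]

    # Budget large enough for the system message: it stays first and intact.
    kept = trim_messages(messages, max_tokens=30, model="gpt-4-0613")
    assert len(kept) == 2
    assert kept[0]["content"] == "This is a short system message"

    # Budget smaller than the system message: only a shortened system message
    # (with '..' in the middle) survives.
    tight = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
    assert len(tight) == 1
    assert '..' in tight[0]["content"]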