Improved trimming logic and OpenAI token counter

Commit 07e8cf1d9a (parent 70311502c8)
2 changed files with 122 additions and 24 deletions
litellm/utils.py (106 lines changed)
@@ -19,6 +19,7 @@ import uuid
 import aiohttp
 import logging
 import asyncio
+import copy
 from tokenizers import Tokenizer
 from dataclasses import (
     dataclass,
@@ -1101,6 +1102,50 @@ def decode(model: str, tokens: List[int]):
     dec = tokenizer_json["tokenizer"].decode(tokens)
     return dec

+def openai_token_counter(messages, model="gpt-3.5-turbo-0613"):
+    """
+    Return the number of tokens used by a list of messages.
+
+    Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb.
+    """
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        return openai_token_counter(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return openai_token_counter(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
+
 def token_counter(model="", text=None, messages: Optional[List] = None):
     """
     Count the number of tokens in a given text using a specified model.
@@ -1121,14 +1166,17 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
         raise ValueError("text and messages cannot both be None")
     num_tokens = 0

     if model is not None:
         tokenizer_json = _select_tokenizer(model=model)
         if tokenizer_json["type"] == "huggingface_tokenizer":
             enc = tokenizer_json["tokenizer"].encode(text)
             num_tokens = len(enc.ids)
         elif tokenizer_json["type"] == "openai_tokenizer":
-            enc = tokenizer_json["tokenizer"].encode(text)
-            num_tokens = len(enc)
+            if messages is not None:
+                num_tokens = openai_token_counter(messages, model=model)
+            else:
+                enc = tokenizer_json["tokenizer"].encode(text)
+                num_tokens = len(enc)
     else:
         num_tokens = len(encoding.encode(text))
     return num_tokens
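For orientation, here is a minimal usage sketch of the counters above. It assumes token_counter is re-exported from litellm/utils.py (so `from litellm import token_counter` works) and that tiktoken is installed; the sample messages are made up for illustration only.

from litellm import token_counter  # assumed re-export of litellm/utils.py

# Hypothetical chat history, for illustration only.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many tokens is this?"},
]

# With `messages` set, the OpenAI branch above delegates to openai_token_counter,
# which adds per-message overhead on top of the content tokens.
print(token_counter(model="gpt-3.5-turbo-0613", messages=messages))

# With only `text`, the counter encodes the raw string directly.
print(token_counter(model="gpt-3.5-turbo-0613", text="How many tokens is this?"))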
@@ -4429,7 +4477,7 @@ def completion_with_config(config: Union[dict, str], **kwargs):
             except:
                 continue
         if prompt_larger_than_model:
-            messages = trim_messages(messages=messages, model=max_model)
+            messages = trim_messages(messages_copy=messages, model=max_model)
             kwargs["messages"] = messages

     kwargs["model"] = model
@@ -4528,13 +4576,13 @@ def completion_with_fallbacks(**kwargs):

 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
-    system_message_tokens = get_token_count(system_message_event, model)
+    system_message_tokens = get_token_count([system_message_event], model)

     if system_message_tokens > max_tokens:
         print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
         # shorten system message to fit within max_tokens
         new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
-        system_message_tokens = get_token_count(new_system_message, model)
+        system_message_tokens = get_token_count([new_system_message], model)

     return system_message_event, max_tokens - system_message_tokens

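The list wrapping above matters because the counting path ultimately iterates a list of message dicts (see openai_token_counter earlier in this diff). A small sketch of the difference, assuming openai_token_counter is importable from litellm.utils and tiktoken is available; the content string is made up.

from litellm.utils import openai_token_counter  # assumed import path

event = {"role": "system", "content": "Keep answers short."}

# Correct: counted as a one-message conversation.
openai_token_counter([event], model="gpt-3.5-turbo-0613")

# The old call shape passed the bare dict; iterating it yields its keys
# ("role", "content"), and calling .items() on a string then raises AttributeError.
# openai_token_counter(event, model="gpt-3.5-turbo-0613")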
@@ -4544,11 +4592,15 @@ def process_messages(messages, max_tokens, model):
     final_messages = []

     for message in messages:
-        final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
+        used_tokens = get_token_count(final_messages, model)
+        available_tokens = max_tokens - used_tokens
+        if available_tokens <= 3:
+            break
+        final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model)

     return final_messages

-def attempt_message_addition(final_messages, message, max_tokens, model):
+def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model):
     temp_messages = [message] + final_messages
     temp_message_tokens = get_token_count(messages=temp_messages, model=model)

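A small worked example of the new budget arithmetic in process_messages; the numbers are invented and the names mirror the diff above.

max_tokens = 100
used_tokens = 68                             # tokens already consumed by final_messages
available_tokens = max_tokens - used_tokens  # 32 tokens left for the next candidate

# attempt_message_addition now shortens the candidate to fit available_tokens
# (the remaining budget) rather than the old temp_message_tokens - max_tokens
# difference, and the loop stops once 3 or fewer tokens remain.
print(available_tokens)  # 32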
@@ -4558,7 +4610,7 @@ def attempt_message_addition(final_messages, message, max_tokens, model):
     # if temp_message_tokens > max_tokens, try shortening temp_messages
     elif "function_call" not in message:
         # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greater than max_tokens)
-        updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
+        updated_message = shorten_message_to_fit_limit(message, available_tokens, model)
         if can_add_message(updated_message, final_messages, max_tokens, model):
             return [updated_message] + final_messages

@@ -4580,6 +4632,13 @@ def shorten_message_to_fit_limit(
     """
    Shorten a message to fit within a token limit by removing characters from the middle.
    """
+
+    # For OpenAI models, even blank messages cost 7 tokens,
+    # and if the buffer is less than 3, the while loop will never end,
+    # hence the value 10.
+    if 'gpt' in model and tokens_needed <= 10:
+        return message
+
     content = message["content"]

     while True:
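A rough sketch of the numbers behind the new guard. It assumes tiktoken's cl100k_base encoding, under which the role string "user" encodes to a single token, and that openai_token_counter is importable from litellm.utils.

from litellm.utils import openai_token_counter  # assumed import path

blank = [{"role": "user", "content": ""}]
overhead = openai_token_counter(blank, model="gpt-3.5-turbo-0613")
# tokens_per_message (3) + encode("user") (1) + encode("") (0) + reply priming (3)
print(overhead)  # expected: 7 under these assumptions

# With a budget that small, repeatedly shortening the content cannot free enough room,
# which is why the function now returns the message unchanged when tokens_needed <= 10.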
@@ -4607,7 +4666,7 @@ def shorten_message_to_fit_limit(
 # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py
 # Credits for this code go to Killian Lucas
 def trim_messages(
-    messages,
+    messages_copy,
     model: Optional[str] = None,
     trim_ratio: float = 0.75,
     return_response_tokens: bool = False,
@@ -4628,6 +4687,7 @@ def trim_messages(
     """
     # Initialize max_tokens
     # if users pass in max tokens, trim to this amount
+    messages_copy = copy.deepcopy(messages_copy)
     try:
         print_verbose(f"trimming messages")
         if max_tokens == None:
@@ -4642,33 +4702,43 @@ def trim_messages(
             return

         system_message = ""
-        for message in messages:
+        for message in messages_copy:
             if message["role"] == "system":
                 system_message += '\n' if system_message else ''
                 system_message += message["content"]

-        current_tokens = token_counter(model=model, messages=messages)
+        current_tokens = token_counter(model=model, messages=messages_copy)
         print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")

         # Do nothing if current tokens under messages
         if current_tokens < max_tokens:
-            return messages
+            return messages_copy

         #### Trimming messages if current_tokens > max_tokens
-        print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
+        print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
         if system_message:
             system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
-            messages = messages + [system_message_event]

-        final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
+            if max_tokens == 0: # the system messages are too long
+                return [system_message_event]
+
+            # Since all system messages are combined and trimmed to fit the max_tokens,
+            # we remove all system messages from the messages list
+            messages_copy = [message for message in messages_copy if message["role"] != "system"]
+
+        final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model)
+
+        # Add system message to the beginning of the final messages
+        if system_message:
+            final_messages = [system_message_event] + final_messages

         if return_response_tokens: # if user wants token count with new trimmed messages
             response_tokens = max_tokens - get_token_count(final_messages, model)
             return final_messages, response_tokens

         return final_messages
     except Exception as e: # [NON-Blocking, if error occurs just return final_messages
         print_verbose(f"Got exception while token trimming{e}")
-        return messages
+        return messages_copy

 def get_valid_models():
     """
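Finally, a hedged sketch of calling the reworked trim_messages. The import path litellm.utils is assumed, the conversation content is made up, and the behavior when max_tokens is not passed is inferred from the signature and docstring (a limit derived from the model, scaled by trim_ratio). Note the first parameter is now named messages_copy, and the caller's list is no longer mutated thanks to the deepcopy added above.

from litellm.utils import trim_messages  # assumed import path

chat = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Summarize this very long document ..."},
]

# Without an explicit max_tokens, a limit is derived from the model (scaled by trim_ratio, default 0.75).
trimmed = trim_messages(messages_copy=chat, model="gpt-3.5-turbo")

# Or also ask for the token budget left for the completion.
trimmed, response_tokens = trim_messages(
    messages_copy=chat,
    model="gpt-3.5-turbo",
    return_response_tokens=True,
)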