forked from phoenix/litellm-mirror

add safe_message trimmer

parent bcb89dcf4a
commit be5a92c40a
2 changed files with 169 additions and 10 deletions
@@ -20,9 +20,18 @@ def test_basic_trimming():
     trimmed_messages = safe_messages(messages, model="claude-2", max_tokens=8)
     print("trimmed messages")
     print(trimmed_messages)
-    print(get_token_count(messages=trimmed_messages, model="claude-2"))
+    # print(get_token_count(messages=trimmed_messages, model="claude-2"))
     assert (get_token_count(messages=trimmed_messages, model="claude-2")) <= 8
-# test_basic_trimming()
+test_basic_trimming()
+
+def test_basic_trimming_no_max_tokens_specified():
+    messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}]
+    trimmed_messages = safe_messages(messages, model="gpt-4")
+    print("trimmed messages for gpt-4")
+    print(trimmed_messages)
+    # print(get_token_count(messages=trimmed_messages, model="claude-2"))
+    assert (get_token_count(messages=trimmed_messages, model="gpt-4")) <= litellm.model_cost['gpt-4']['max_tokens']
+test_basic_trimming_no_max_tokens_specified()
 
 def test_multiple_messages_trimming():
     messages = [
@@ -32,9 +41,9 @@ def test_multiple_messages_trimming():
     trimmed_messages = safe_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20)
     print("Trimmed messages")
     print(trimmed_messages)
-    print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
+    # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
     assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20
-# test_multiple_messages_trimming()
+test_multiple_messages_trimming()
 
 def test_multiple_messages_no_trimming():
     messages = [
@@ -46,7 +55,7 @@ def test_multiple_messages_no_trimming():
     print(trimmed_messages)
     assert(messages==trimmed_messages)
 
-# test_multiple_messages_no_trimming()
+test_multiple_messages_no_trimming()
 
 
 def test_large_trimming():
@@ -55,4 +64,4 @@ def test_large_trimming():
     print("trimmed messages")
     print(trimmed_messages)
     assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
-# test_large_trimming()
+test_large_trimming()
158  litellm/utils.py
@@ -626,17 +626,26 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
     return a100_80gb_price_per_second_public*total_time
 
 
-def token_counter(model, text):
+def token_counter(model="", text=None, messages = None):
+    # Args:
+    # text: raw text string passed to model
+    # messages: List of Dicts passed to completion, messages = [{"role": "user", "content": "hello"}]
     # use tiktoken or anthropic's tokenizer depending on the model
+    if text == None:
+        if messages != None:
+            text = " ".join([message["content"] for message in messages])
     num_tokens = 0
-    if "claude" in model:
+
+    if model != None and "claude" in model:
         try:
             import anthropic
         except Exception:
-            Exception("Anthropic import failed please run `pip install anthropic`")
+            # if importing anthropic fails
+            # don't raise an exception
+            num_tokens = len(encoding.encode(text))
+            return num_tokens
+
         from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 
         anthropic = Anthropic()
         num_tokens = anthropic.count_tokens(text)
     else:
@@ -2352,3 +2361,144 @@ def completion_with_fallbacks(**kwargs):
         # print(f"rate_limited_models {rate_limited_models}")
         pass
     return response
+
+def process_system_message(system_message, max_tokens, model):
+    system_message_event = {"role": "system", "content": system_message}
+    system_message_tokens = get_token_count(system_message_event, model)
+
+    if system_message_tokens > max_tokens:
+        print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
+        # shorten system message to fit within max_tokens
+        new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
+        system_message_tokens = get_token_count(new_system_message, model)
+
+    return system_message_event, max_tokens - system_message_tokens
+
+def process_messages(messages, max_tokens, model):
+    # Process messages from older to more recent
+    messages = messages[::-1]
+    final_messages = []
+
+    for message in messages:
+        final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
+
+    return final_messages
+
+def attempt_message_addition(final_messages, message, max_tokens, model):
+    temp_messages = [message] + final_messages
+    temp_message_tokens = get_token_count(messages=temp_messages, model=model)
+
+    if temp_message_tokens <= max_tokens:
+        return temp_messages
+
+    # if temp_message_tokens > max_tokens, try shortening temp_messages
+    elif "function_call" not in message:
+        # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greater than max_tokens)
+        updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
+        if can_add_message(updated_message, final_messages, max_tokens, model):
+            return [updated_message] + final_messages
+
+    return final_messages
+
+def can_add_message(message, messages, max_tokens, model):
+    if get_token_count(messages + [message], model) <= max_tokens:
+        return True
+    return False
+
+def get_token_count(messages, model):
+    return token_counter(model=model, messages=messages)
+
+
+def shorten_message_to_fit_limit(
+    message,
+    tokens_needed,
+    model):
+    """
+    Shorten a message to fit within a token limit by removing characters from the middle.
+    """
+    content = message["content"]
+
+    while True:
+        total_tokens = get_token_count([message], model)
+
+        if total_tokens <= tokens_needed:
+            break
+
+        ratio = (tokens_needed) / total_tokens
+
+        new_length = int(len(content) * ratio)
+        print_verbose(new_length)
+
+        half_length = new_length // 2
+        left_half = content[:half_length]
+        right_half = content[-half_length:]
+
+        trimmed_content = left_half + '..' + right_half
+        message["content"] = trimmed_content
+        content = trimmed_content
+
+    return message
+
+# LiteLLM token trimmer
+# this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py
+# Credits for this code go to Killian Lucas
+def safe_messages(
+    messages,
+    model = None,
+    system_message = None,
+    trim_ratio: float = 0.75,
+    return_response_tokens: bool = False,
+    max_tokens = None
+    ):
+    """
+    Trim a list of messages to fit within a model's token limit.
+
+    Args:
+        messages: Input messages to be trimmed. Each message is a dictionary with 'role' and 'content'.
+        model: The LiteLLM model being used (determines the token limit).
+        system_message: Optional system message to preserve at the start of the conversation.
+        trim_ratio: Target ratio of tokens to use after trimming. Default is 0.75, meaning it will trim messages so they use about 75% of the model's token limit.
+        return_response_tokens: If True, also return the number of tokens left available for the response after trimming.
+        max_tokens: Instead of specifying a model or trim_ratio, you can specify this directly.
+
+    Returns:
+        Trimmed messages and optionally the number of tokens available for response.
+    """
+    # Initialize max_tokens
+    # if users pass in max tokens, trim to this amount
+    try:
+        if max_tokens == None:
+            # Check if model is valid
+            if model in litellm.model_cost:
+                max_tokens_for_model = litellm.model_cost[model]['max_tokens']
+                max_tokens = int(max_tokens_for_model * trim_ratio)
+            else:
+                # if user did not specify max tokens
+                # or passed an llm litellm does not know
+                # do nothing, just return messages
+                return
+
+        current_tokens = token_counter(model=model, messages=messages)
+
+        # Do nothing if current_tokens is under max_tokens
+        if current_tokens < max_tokens:
+            return messages
+
+        #### Trimming messages if current_tokens > max_tokens
+        print_verbose(f"Need to trim input messages: {messages}, current_tokens: {current_tokens}, max_tokens: {max_tokens}")
+        if system_message:
+            system_message_event, max_tokens = process_system_message(messages=messages, max_tokens=max_tokens, model=model)
+
+        final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
+
+        if system_message:
+            final_messages = [system_message_event] + final_messages
+
+        if return_response_tokens:  # if user wants token count with new trimmed messages
+            response_tokens = max_tokens - get_token_count(final_messages, model)
+            return final_messages, response_tokens
+
+        return final_messages
+    except:  # [Non-Blocking] if an error occurs, just return the original messages
+        return messages
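For reference, a minimal sketch of calling the updated token_counter with either raw text or a messages list. The litellm.utils import path and the example strings are assumptions for illustration, not part of this commit.

# Sketch: the updated token_counter accepts either raw text or a messages list.
# Import path and example values are assumptions for illustration.
from litellm.utils import token_counter

# Both calls count the same underlying text, since the messages path joins the
# message contents into one string before tokenizing.
print(token_counter(model="gpt-3.5-turbo", text="hello, how are you?"))
print(token_counter(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hello, how are you?"}]))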
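Likewise, a minimal usage sketch of the new safe_messages trimmer. The message contents, the 50-token budget, and the import path are illustrative assumptions; the behavior described follows the code added in this diff.

# Sketch: trim a conversation to an explicit token budget before calling completion.
# Import path, message contents, and the 50-token budget are illustrative assumptions.
from litellm.utils import safe_messages, token_counter

messages = [
    {"role": "user", "content": "older, very long question " * 200},
    {"role": "user", "content": "newest short question"},
]

# More recent messages are kept preferentially; an over-long message is either
# shortened from the middle (the code inserts '..' at the cut) or dropped.
trimmed = safe_messages(messages, model="gpt-3.5-turbo", max_tokens=50)
print(token_counter(model="gpt-3.5-turbo", messages=trimmed))  # should be <= 50

# With return_response_tokens=True the call also reports the tokens left for the response.
trimmed, response_tokens = safe_messages(
    messages, model="gpt-3.5-turbo", max_tokens=50, return_response_tokens=True
)

Passing max_tokens directly skips the litellm.model_cost lookup, which is why test_large_trimming above can use model="random".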