diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py
index f59e647be..cd96f2848 100644
--- a/litellm/tests/test_utils.py
+++ b/litellm/tests/test_utils.py
@@ -20,9 +20,18 @@ def test_basic_trimming():
     trimmed_messages = safe_messages(messages, model="claude-2", max_tokens=8)
     print("trimmed messages")
     print(trimmed_messages)
-    print(get_token_count(messages=trimmed_messages, model="claude-2"))
+    # print(get_token_count(messages=trimmed_messages, model="claude-2"))
     assert (get_token_count(messages=trimmed_messages, model="claude-2")) <= 8
-# test_basic_trimming()
+test_basic_trimming()
+
+def test_basic_trimming_no_max_tokens_specified():
+    messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}]
+    trimmed_messages = safe_messages(messages, model="gpt-4")
+    print("trimmed messages for gpt-4")
+    print(trimmed_messages)
+    # print(get_token_count(messages=trimmed_messages, model="claude-2"))
+    assert (get_token_count(messages=trimmed_messages, model="gpt-4")) <= litellm.model_cost['gpt-4']['max_tokens']
+test_basic_trimming_no_max_tokens_specified()
 
 def test_multiple_messages_trimming():
     messages = [
@@ -32,9 +41,9 @@ def test_multiple_messages_trimming():
     trimmed_messages = safe_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20)
     print("Trimmed messages")
     print(trimmed_messages)
-    print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
+    # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
     assert(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) <= 20
-# test_multiple_messages_trimming()
+test_multiple_messages_trimming()
 
 def test_multiple_messages_no_trimming():
     messages = [
@@ -46,7 +55,7 @@ def test_multiple_messages_no_trimming():
     print(trimmed_messages)
     assert(messages==trimmed_messages)
 
-# test_multiple_messages_no_trimming()
+test_multiple_messages_no_trimming()
 
 def test_large_trimming():
@@ -55,4 +64,4 @@ def test_large_trimming():
     print("trimmed messages")
     print(trimmed_messages)
     assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
-# test_large_trimming()
\ No newline at end of file
+test_large_trimming()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 2a19e885b..3a2c3f954 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -626,17 +626,26 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
     return a100_80gb_price_per_second_public*total_time
 
 
-def token_counter(model, text):
+def token_counter(model="", text=None, messages = None):
+    # Args:
+    # text: raw text string passed to model
+    # messages: List of Dicts passed to completion, messages = [{"role": "user", "content": "hello"}]
     # use tiktoken or anthropic's tokenizer depending on the model
+    if text == None:
+        if messages != None:
+            text = " ".join([message["content"] for message in messages])
     num_tokens = 0
-    if "claude" in model:
+
+    if model != None and "claude" in model:
         try:
             import anthropic
         except Exception:
-            Exception("Anthropic import failed please run `pip install anthropic`")
+            # if importing anthropic fails
+            # don't raise an exception
+            num_tokens = len(encoding.encode(text))
+            return num_tokens
         from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
-
         anthropic = Anthropic()
         num_tokens = anthropic.count_tokens(text)
     else:
@@ -2352,3 +2361,144 @@ def completion_with_fallbacks(**kwargs):
             # print(f"rate_limited_models {rate_limited_models}")
             pass
     return response
+
+def process_system_message(system_message, max_tokens, model):
+    system_message_event = {"role": "system", "content": system_message}
+    system_message_tokens = get_token_count([system_message_event], model)
+
+    if system_message_tokens > max_tokens:
+        print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
+        # shorten system message to fit within max_tokens
+        new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
+        system_message_tokens = get_token_count([new_system_message], model)
+
+    return system_message_event, max_tokens - system_message_tokens
+
+def process_messages(messages, max_tokens, model):
+    # Process messages from older to more recent
+    messages = messages[::-1]
+    final_messages = []
+
+    for message in messages:
+        final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
+
+    return final_messages
+
+def attempt_message_addition(final_messages, message, max_tokens, model):
+    temp_messages = [message] + final_messages
+    temp_message_tokens = get_token_count(messages=temp_messages, model=model)
+
+    if temp_message_tokens <= max_tokens:
+        return temp_messages
+
+    # if temp_message_tokens > max_tokens, try shortening temp_messages
+    elif "function_call" not in message:
+        # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greater than max_tokens)
+        updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
+        if can_add_message(updated_message, final_messages, max_tokens, model):
+            return [updated_message] + final_messages
+
+    return final_messages
+
+def can_add_message(message, messages, max_tokens, model):
+    if get_token_count(messages + [message], model) <= max_tokens:
+        return True
+    return False
+
+def get_token_count(messages, model):
+    return token_counter(model=model, messages=messages)
+
+
+def shorten_message_to_fit_limit(
+    message,
+    tokens_needed,
+    model):
+    """
+    Shorten a message to fit within a token limit by removing characters from the middle.
+    """
+    content = message["content"]
+
+    while True:
+        total_tokens = get_token_count([message], model)
+
+        if total_tokens <= tokens_needed:
+            break
+
+        ratio = (tokens_needed) / total_tokens
+
+        new_length = int(len(content) * ratio)
+        print_verbose(new_length)
+
+        half_length = new_length // 2
+        left_half = content[:half_length]
+        right_half = content[-half_length:]
+
+        trimmed_content = left_half + '..' + right_half
+        message["content"] = trimmed_content
+        content = trimmed_content
+
+    return message
+
+# LiteLLM token trimmer
+# this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py
+# Credits for this code go to Killian Lucas
+def safe_messages(
+    messages,
+    model = None,
+    system_message = None,
+    trim_ratio: float = 0.75,
+    return_response_tokens: bool = False,
+    max_tokens = None
+    ):
+    """
+    Trim a list of messages to fit within a model's token limit.
+
+    Args:
+        messages: Input messages to be trimmed. Each message is a dictionary with 'role' and 'content'.
+        model: The LiteLLM model being used (determines the token limit).
+        system_message: Optional system message to preserve at the start of the conversation.
+        trim_ratio: Target ratio of tokens to use after trimming. Default is 0.75, meaning it will trim messages so they use about 75% of the model's token limit.
+        return_response_tokens: If True, also return the number of tokens left available for the response after trimming.
+        max_tokens: Instead of specifying a model or trim_ratio, you can specify this directly.
+
+    Returns:
+        Trimmed messages and optionally the number of tokens available for response.
+    """
+    # Initialize max_tokens
+    # if users pass in max tokens, trim to this amount
+    try:
+        if max_tokens == None:
+            # Check if model is valid
+            if model in litellm.model_cost:
+                max_tokens_for_model = litellm.model_cost[model]['max_tokens']
+                max_tokens = int(max_tokens_for_model * trim_ratio)
+            else:
+                # if user did not specify max tokens
+                # or passed an llm litellm does not know
+                # do nothing, just return messages
+                return messages
+
+        current_tokens = token_counter(model=model, messages=messages)
+
+        # Do nothing if current tokens are under max_tokens
+        if current_tokens < max_tokens:
+            return messages
+
+        #### Trimming messages if current_tokens > max_tokens
+        print_verbose(f"Need to trim input messages: {messages}, current_tokens: {current_tokens}, max_tokens: {max_tokens}")
+        if system_message:
+            system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
+
+        final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
+
+        if system_message:
+            final_messages = [system_message_event] + final_messages
+
+        if return_response_tokens:  # if user wants token count with new trimmed messages
+            response_tokens = max_tokens - get_token_count(final_messages, model)
+            return final_messages, response_tokens
+
+        return final_messages
+    except:  # non-blocking: if an error occurs, just return the unmodified messages
+        return messages
+
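For reference, a minimal usage sketch (not part of the patch) of the trimming helpers added above, mirroring the tests in test_utils.py. It assumes safe_messages and token_counter are imported from litellm.utils as defined in this diff; the model name and message contents are illustrative only.

# Usage sketch only -- not part of the patch above.
from litellm.utils import safe_messages, token_counter

messages = [
    {"role": "user", "content": "first question " * 50},
    {"role": "assistant", "content": "a long earlier answer " * 50},
    {"role": "user", "content": "follow-up question " * 50},
]

# Trim to an explicit token budget (as in test_multiple_messages_trimming).
trimmed = safe_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20)
print(token_counter(model="gpt-3.5-turbo", messages=trimmed))  # expected to be <= 20

# Or derive the budget from the model's max_tokens * trim_ratio (default 0.75).
# Note: per the implementation above, the (messages, response_tokens) tuple is
# only returned when trimming actually happens; otherwise the original messages
# come back unchanged, so the result is not unpacked here.
result = safe_messages(messages=messages, model="gpt-3.5-turbo", return_response_tokens=True)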