From ebc889d77a177f360592167839d59180456bd8f8 Mon Sep 17 00:00:00 2001
From: Chen Wang
Date: Tue, 16 Apr 2024 13:36:25 +0100
Subject: [PATCH 1/2] Use `max_input_tokens` for `trim_messages`

---
 litellm/tests/test_utils.py | 16 ++++++++++++++++
 litellm/utils.py            | 13 ++++++++-----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py
index 0344f21146..44fb1607c3 100644
--- a/litellm/tests/test_utils.py
+++ b/litellm/tests/test_utils.py
@@ -173,6 +173,22 @@ def test_trimming_should_not_change_original_messages():
     assert messages == messages_copy
 
 
+@pytest.mark.parametrize("model", ["gpt-4-0125-preview", "claude-3-opus-20240229"])
+def test_trimming_with_model_cost_max_input_tokens(model):
+    messages = [
+        {"role": "system", "content": "This is a normal system message"},
+        {
+            "role": "user",
+            "content": "This is a sentence" * 100000,
+        },
+    ]
+    trimmed_messages = trim_messages(messages, model=model)
+    assert (
+        get_token_count(trimmed_messages, model=model)
+        < litellm.model_cost[model]["max_input_tokens"]
+    )
+
+
 def test_get_valid_models():
     old_environ = os.environ
     os.environ = {"OPENAI_API_KEY": "temp"}  # mock set only openai key in environ

diff --git a/litellm/utils.py b/litellm/utils.py
index 3a319f24b0..c497c63268 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10577,16 +10577,19 @@ def trim_messages(
     messages = copy.deepcopy(messages)
     try:
         print_verbose(f"trimming messages")
-        if max_tokens == None:
+        if max_tokens is None:
             # Check if model is valid
-            if model in litellm.model_cost:
-                max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
+            if (
+                model in litellm.model_cost
+                and "max_input_tokens" in litellm.model_cost[model]
+            ):
+                max_tokens_for_model = litellm.model_cost[model]["max_input_tokens"]
                 max_tokens = int(max_tokens_for_model * trim_ratio)
             else:
-                # if user did not specify max tokens
+                # if user did not specify max input tokens
                 # or passed an llm litellm does not know
                 # do nothing, just return messages
-                return
+                return messages
 
         system_message = ""
         for message in messages:

From 38c61a23b4d024dd46dab69905f1a885fd427112 Mon Sep 17 00:00:00 2001
From: Chen Wang
Date: Tue, 16 Apr 2024 19:00:09 +0100
Subject: [PATCH 2/2] Fall back to `max_tokens`

---
 litellm/utils.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/litellm/utils.py b/litellm/utils.py
index c497c63268..3260b1e157 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10579,14 +10579,11 @@ def trim_messages(
         print_verbose(f"trimming messages")
         if max_tokens is None:
             # Check if model is valid
-            if (
-                model in litellm.model_cost
-                and "max_input_tokens" in litellm.model_cost[model]
-            ):
-                max_tokens_for_model = litellm.model_cost[model]["max_input_tokens"]
+            if model in litellm.model_cost:
+                max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"])
                 max_tokens = int(max_tokens_for_model * trim_ratio)
             else:
-                # if user did not specify max input tokens
+                # if user did not specify max (input) tokens
                 # or passed an llm litellm does not know
                 # do nothing, just return messages
                 return messages
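
Reviewer note: for quick reference, here is the budget-resolution logic the series converges on, extracted as a standalone sketch. The resolve_trim_budget helper and the inline model_cost values below are illustrative stand-ins, not litellm's actual API or pricing data; the real code reads litellm.model_cost and scales the limit by trim_ratio.

from typing import Optional

# Illustrative model_cost entries; real values live in litellm.model_cost.
model_cost = {
    "gpt-4-0125-preview": {"max_tokens": 4096, "max_input_tokens": 128000},
    "legacy-model": {"max_tokens": 8192},  # older entries lack max_input_tokens
}


def resolve_trim_budget(model: str, trim_ratio: float) -> Optional[int]:
    """Token budget trim_messages would derive when max_tokens is not given."""
    if model not in model_cost:
        return None  # unknown model: trim_messages returns messages unchanged
    entry = model_cost[model]
    # Patch 2/2: prefer the input-context limit, falling back to max_tokens
    # for entries that predate the max_input_tokens field.
    return int(entry.get("max_input_tokens", entry["max_tokens"]) * trim_ratio)


assert resolve_trim_budget("gpt-4-0125-preview", 0.75) == 96000
assert resolve_trim_budget("legacy-model", 0.75) == 6144
assert resolve_trim_budget("unknown-model", 0.75) is None

Note the design choice: patch 1/2 alone would skip trimming entirely for models whose model_cost entry has no max_input_tokens; the .get() fallback in patch 2/2 restores the previous max_tokens-based behaviour for those entries.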
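
A minimal usage sketch mirroring the new test, assuming litellm is installed and the model key is present in litellm.model_cost (the import paths follow the files touched by this series):

import litellm
from litellm.utils import get_token_count, trim_messages

messages = [
    {"role": "system", "content": "This is a normal system message"},
    {"role": "user", "content": "This is a sentence" * 100000},
]

# No explicit max_tokens: the budget is now derived from the model's
# max_input_tokens (times trim_ratio) rather than its smaller max_tokens.
trimmed = trim_messages(messages, model="gpt-4-0125-preview")

assert (
    get_token_count(trimmed, model="gpt-4-0125-preview")
    < litellm.model_cost["gpt-4-0125-preview"]["max_input_tokens"]
)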