diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py
index 0344f2114..44fb1607c 100644
--- a/litellm/tests/test_utils.py
+++ b/litellm/tests/test_utils.py
@@ -173,6 +173,22 @@ def test_trimming_should_not_change_original_messages():
     assert messages == messages_copy
 
 
+@pytest.mark.parametrize("model", ["gpt-4-0125-preview", "claude-3-opus-20240229"])
+def test_trimming_with_model_cost_max_input_tokens(model):
+    messages = [
+        {"role": "system", "content": "This is a normal system message"},
+        {
+            "role": "user",
+            "content": "This is a sentence" * 100000,
+        },
+    ]
+    trimmed_messages = trim_messages(messages, model=model)
+    assert (
+        get_token_count(trimmed_messages, model=model)
+        < litellm.model_cost[model]["max_input_tokens"]
+    )
+
+
 def test_get_valid_models():
     old_environ = os.environ
     os.environ = {"OPENAI_API_KEY": "temp"}  # mock set only openai key in environ
diff --git a/litellm/utils.py b/litellm/utils.py
index 47fbffa47..dd538c7d0 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10592,16 +10592,16 @@ def trim_messages(
     messages = copy.deepcopy(messages)
     try:
         print_verbose(f"trimming messages")
-        if max_tokens == None:
+        if max_tokens is None:
             # Check if model is valid
             if model in litellm.model_cost:
-                max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
+                max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"])
                 max_tokens = int(max_tokens_for_model * trim_ratio)
             else:
-                # if user did not specify max tokens
+                # if user did not specify max (input) tokens
                 # or passed an llm litellm does not know
                 # do nothing, just return messages
-                return
+                return messages
 
         system_message = ""
         for message in messages:
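For reference, a minimal sketch of the two behavior changes this patch makes: trimming now targets `max_input_tokens` from `litellm.model_cost` (falling back to `max_tokens` when that key is absent), and an unrecognized model now returns the messages unchanged instead of `None`. The import path for `get_token_count` and `trim_messages` is assumed to match the test module above; the unknown model name is hypothetical.

```python
import litellm
from litellm.utils import get_token_count, trim_messages  # assumed import path

model = "gpt-4-0125-preview"  # has "max_input_tokens" in litellm.model_cost
messages = [
    {"role": "system", "content": "This is a normal system message"},
    {"role": "user", "content": "This is a sentence" * 100000},
]

# With the fix, the trim budget is max_input_tokens * trim_ratio rather
# than the model's combined max_tokens figure, so the result fits the
# model's input window.
trimmed = trim_messages(messages, model=model)
assert (
    get_token_count(trimmed, model=model)
    < litellm.model_cost[model]["max_input_tokens"]
)

# An LLM litellm does not know about now round-trips the messages
# instead of implicitly returning None.
assert trim_messages(messages, model="a-model-litellm-does-not-know") == messages
```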