Merge pull request #3062 from cwang/cwang/trim-messages-fix

Use `max_input_token` for `trim_messages`
Krish Dholakia 2024-04-16 22:29:45 -07:00 committed by GitHub
commit 8febe2f573
2 changed files with 20 additions and 4 deletions


@@ -173,6 +173,22 @@ def test_trimming_should_not_change_original_messages():
     assert messages == messages_copy


+@pytest.mark.parametrize("model", ["gpt-4-0125-preview", "claude-3-opus-20240229"])
+def test_trimming_with_model_cost_max_input_tokens(model):
+    messages = [
+        {"role": "system", "content": "This is a normal system message"},
+        {
+            "role": "user",
+            "content": "This is a sentence" * 100000,
+        },
+    ]
+    trimmed_messages = trim_messages(messages, model=model)
+    assert (
+        get_token_count(trimmed_messages, model=model)
+        < litellm.model_cost[model]["max_input_tokens"]
+    )
+
+
 def test_get_valid_models():
     old_environ = os.environ
     os.environ = {"OPENAI_API_KEY": "temp"}  # mock set only openai key in environ
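For reference, the behavior this test pins down can be reproduced outside pytest. A minimal standalone sketch, assuming a litellm version that includes this change (`trim_messages` and `get_token_count` are the helpers from `litellm.utils` used in the test above):

import litellm
from litellm.utils import trim_messages, get_token_count

model = "gpt-4-0125-preview"
messages = [
    {"role": "system", "content": "This is a normal system message"},
    {"role": "user", "content": "This is a sentence" * 100000},
]

# trim_messages should now budget against the model's max_input_tokens
# entry in litellm.model_cost rather than the larger max_tokens value.
trimmed = trim_messages(messages, model=model)
assert get_token_count(trimmed, model=model) < litellm.model_cost[model]["max_input_tokens"]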


@@ -10592,16 +10592,16 @@ def trim_messages(
     messages = copy.deepcopy(messages)
     try:
         print_verbose(f"trimming messages")
-        if max_tokens == None:
+        if max_tokens is None:
             # Check if model is valid
             if model in litellm.model_cost:
-                max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
+                max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"])
                 max_tokens = int(max_tokens_for_model * trim_ratio)
             else:
-                # if user did not specify max tokens
+                # if user did not specify max (input) tokens
                 # or passed an llm litellm does not know
                 # do nothing, just return messages
-                return
+                return messages

         system_message = ""
         for message in messages:
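The substantive change is the fallback lookup: prefer the `max_input_tokens` field of the model's `litellm.model_cost` entry when it exists, and fall back to the older `max_tokens` field otherwise. A small illustrative sketch of that selection (the model name and printout are only for demonstration):

import litellm

model = "claude-3-opus-20240229"
entry = litellm.model_cost.get(model, {})

# Prefer the input-token limit when the cost map provides one; otherwise
# fall back to the legacy "max_tokens" field, mirroring the change above.
budget = entry.get("max_input_tokens", entry.get("max_tokens"))
print(f"{model}: trim_messages budget = {budget} tokens")

Callers that need a hard cap can still pass `max_tokens=` to `trim_messages` directly; per the `if max_tokens is None:` branch above, an explicit value skips the `model_cost` lookup entirely.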