Use max_input_tokens for trim_messages

Chen Wang 2024-04-16 13:36:25 +01:00
parent 7ffd3d40fa
commit ebc889d77a
2 changed files with 24 additions and 5 deletions

@@ -10577,16 +10577,19 @@ def trim_messages(
     messages = copy.deepcopy(messages)
     try:
         print_verbose(f"trimming messages")
-        if max_tokens == None:
+        if max_tokens is None:
             # Check if model is valid
-            if model in litellm.model_cost:
-                max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
+            if (
+                model in litellm.model_cost
+                and "max_input_tokens" in litellm.model_cost[model]
+            ):
+                max_tokens_for_model = litellm.model_cost[model]["max_input_tokens"]
                 max_tokens = int(max_tokens_for_model * trim_ratio)
             else:
-                # if user did not specify max tokens
+                # if user did not specify max input tokens
                 # or passed an llm litellm does not know
                 # do nothing, just return messages
-                return
+                return messages
         system_message = ""
         for message in messages:
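
For context, a minimal usage sketch of the behavior after this change, assuming the trim_messages helper exported from litellm.utils with its default trim_ratio; the model names below are illustrative, not taken from this diff:

# A minimal sketch, assuming trim_messages is exported from litellm.utils.
from litellm.utils import trim_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this very long transcript ..."},
]

# With max_tokens left unset, the budget is now derived from
# litellm.model_cost[model]["max_input_tokens"] * trim_ratio rather than
# ["max_tokens"], so trimming targets the model's input context window
# instead of its maximum completion size.
trimmed = trim_messages(messages, model="gpt-3.5-turbo")

# For a model litellm does not know (or one whose model_cost entry lacks
# max_input_tokens), the function now returns the messages unchanged
# instead of returning None.
unchanged = trim_messages(messages, model="my-custom-llm")  # hypothetical model name

Deriving the budget from max_input_tokens matters for models whose max_tokens entry describes the completion limit rather than the context window; the extra "max_input_tokens" key check also keeps older model_cost entries without that field on the do-nothing path.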