mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge pull request #3062 from cwang/cwang/trim-messages-fix
Use `max_input_token` for `trim_messages`
This commit is contained in:
commit
8febe2f573
2 changed files with 20 additions and 4 deletions
|
@ -173,6 +173,22 @@ def test_trimming_should_not_change_original_messages():
|
||||||
assert messages == messages_copy
|
assert messages == messages_copy
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["gpt-4-0125-preview", "claude-3-opus-20240229"])
|
||||||
|
def test_trimming_with_model_cost_max_input_tokens(model):
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "This is a normal system message"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "This is a sentence" * 100000,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
trimmed_messages = trim_messages(messages, model=model)
|
||||||
|
assert (
|
||||||
|
get_token_count(trimmed_messages, model=model)
|
||||||
|
< litellm.model_cost[model]["max_input_tokens"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_valid_models():
|
def test_get_valid_models():
|
||||||
old_environ = os.environ
|
old_environ = os.environ
|
||||||
os.environ = {"OPENAI_API_KEY": "temp"} # mock set only openai key in environ
|
os.environ = {"OPENAI_API_KEY": "temp"} # mock set only openai key in environ
|
||||||
|
|
|
@ -10592,16 +10592,16 @@ def trim_messages(
|
||||||
messages = copy.deepcopy(messages)
|
messages = copy.deepcopy(messages)
|
||||||
try:
|
try:
|
||||||
print_verbose(f"trimming messages")
|
print_verbose(f"trimming messages")
|
||||||
if max_tokens == None:
|
if max_tokens is None:
|
||||||
# Check if model is valid
|
# Check if model is valid
|
||||||
if model in litellm.model_cost:
|
if model in litellm.model_cost:
|
||||||
max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
|
max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"])
|
||||||
max_tokens = int(max_tokens_for_model * trim_ratio)
|
max_tokens = int(max_tokens_for_model * trim_ratio)
|
||||||
else:
|
else:
|
||||||
# if user did not specify max tokens
|
# if user did not specify max (input) tokens
|
||||||
# or passed an llm litellm does not know
|
# or passed an llm litellm does not know
|
||||||
# do nothing, just return messages
|
# do nothing, just return messages
|
||||||
return
|
return messages
|
||||||
|
|
||||||
system_message = ""
|
system_message = ""
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue