rename safe_messages to trim_messages

This commit is contained in:
ishaan-jaff 2023-09-11 18:54:18 -07:00
parent 7d41e701a6
commit dda87a55ff
3 changed files with 12 additions and 12 deletions

View file

@ -1,25 +1,25 @@
# Trimming Input Messages # Trimming Input Messages
**Use litellm.safe_messages() to ensure messages does not exceed a model's token limit or specified `max_tokens`** **Use litellm.trim_messages() to ensure messages does not exceed a model's token limit or specified `max_tokens`**
## Usage ## Usage
```python ```python
from litellm import completion from litellm import completion
from litellm.utils import safe_messages from litellm.utils import trim_messages
response = completion( response = completion(
model=model, model=model,
messages=safe_messages(messages, model) # safe_messages ensures tokens(messages) < max_tokens(model) messages=trim_messages(messages, model) # trim_messages ensures tokens(messages) < max_tokens(model)
) )
``` ```
## Usage - set max_tokens ## Usage - set max_tokens
```python ```python
from litellm import completion from litellm import completion
from litellm.utils import safe_messages from litellm.utils import trim_messages
response = completion( response = completion(
model=model, model=model,
messages=safe_messages(messages, model, max_tokens=10), # safe_messages ensures tokens(messages) < max_tokens messages=trim_messages(messages, model, max_tokens=10), # trim_messages ensures tokens(messages) < max_tokens
) )
``` ```

View file

@ -10,14 +10,14 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
import litellm import litellm
from litellm.utils import safe_messages, get_token_count from litellm.utils import trim_messages, get_token_count
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils' # Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
# Test 1: Check trimming of normal message # Test 1: Check trimming of normal message
def test_basic_trimming(): def test_basic_trimming():
messages = [{"role": "user", "content": "This is a long message that definitely exceeds the token limit."}] messages = [{"role": "user", "content": "This is a long message that definitely exceeds the token limit."}]
trimmed_messages = safe_messages(messages, model="claude-2", max_tokens=8) trimmed_messages = trim_messages(messages, model="claude-2", max_tokens=8)
print("trimmed messages") print("trimmed messages")
print(trimmed_messages) print(trimmed_messages)
# print(get_token_count(messages=trimmed_messages, model="claude-2")) # print(get_token_count(messages=trimmed_messages, model="claude-2"))
@ -26,7 +26,7 @@ test_basic_trimming()
def test_basic_trimming_no_max_tokens_specified(): def test_basic_trimming_no_max_tokens_specified():
messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}] messages = [{"role": "user", "content": "This is a long message that is definitely under the token limit."}]
trimmed_messages = safe_messages(messages, model="gpt-4") trimmed_messages = trim_messages(messages, model="gpt-4")
print("trimmed messages for gpt-4") print("trimmed messages for gpt-4")
print(trimmed_messages) print(trimmed_messages)
# print(get_token_count(messages=trimmed_messages, model="claude-2")) # print(get_token_count(messages=trimmed_messages, model="claude-2"))
@ -38,7 +38,7 @@ def test_multiple_messages_trimming():
{"role": "user", "content": "This is a long message that will exceed the token limit."}, {"role": "user", "content": "This is a long message that will exceed the token limit."},
{"role": "user", "content": "This is another long message that will also exceed the limit."} {"role": "user", "content": "This is another long message that will also exceed the limit."}
] ]
trimmed_messages = safe_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20) trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=20)
print("Trimmed messages") print("Trimmed messages")
print(trimmed_messages) print(trimmed_messages)
# print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo")) # print(get_token_count(messages=trimmed_messages, model="gpt-3.5-turbo"))
@ -50,7 +50,7 @@ def test_multiple_messages_no_trimming():
{"role": "user", "content": "This is a long message that will exceed the token limit."}, {"role": "user", "content": "This is a long message that will exceed the token limit."},
{"role": "user", "content": "This is another long message that will also exceed the limit."} {"role": "user", "content": "This is another long message that will also exceed the limit."}
] ]
trimmed_messages = safe_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100) trimmed_messages = trim_messages(messages=messages, model="gpt-3.5-turbo", max_tokens=100)
print("Trimmed messages") print("Trimmed messages")
print(trimmed_messages) print(trimmed_messages)
assert(messages==trimmed_messages) assert(messages==trimmed_messages)
@ -60,7 +60,7 @@ test_multiple_messages_no_trimming()
def test_large_trimming(): def test_large_trimming():
messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}] messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
trimmed_messages = safe_messages(messages, max_tokens=20, model="random") trimmed_messages = trim_messages(messages, max_tokens=20, model="random")
print("trimmed messages") print("trimmed messages")
print(trimmed_messages) print(trimmed_messages)
assert(get_token_count(messages=trimmed_messages, model="random")) <= 20 assert(get_token_count(messages=trimmed_messages, model="random")) <= 20

View file

@ -2496,7 +2496,7 @@ def shorten_message_to_fit_limit(
# LiteLLM token trimmer # LiteLLM token trimmer
# this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py
# Credits for this code go to Killian Lucas # Credits for this code go to Killian Lucas
def safe_messages( def trim_messages(
messages, messages,
model = None, model = None,
system_message = None, # str of user system message system_message = None, # str of user system message