fix(utils.py): fix trim_messages to handle tool calling

Fixes https://github.com/BerriAI/litellm/issues/4931
This commit is contained in:
Krrish Dholakia 2024-07-29 13:04:41 -07:00
parent dd2d61bfce
commit ae4bcd8a41
4 changed files with 100 additions and 11 deletions

View file

@ -1,12 +1,12 @@
repos: repos:
- repo: local - repo: local
hooks: hooks:
- id: mypy # - id: mypy
name: mypy # name: mypy
entry: python3 -m mypy --ignore-missing-imports # entry: python3 -m mypy --ignore-missing-imports
language: system # language: system
types: [python] # types: [python]
files: ^litellm/ # files: ^litellm/
- id: isort - id: isort
name: isort name: isort
entry: isort entry: isort

View file

@ -173,6 +173,71 @@ def test_trimming_with_system_message_exceeding_max_tokens():
assert len(trimmed_messages) == 1 assert len(trimmed_messages) == 1
def test_trimming_with_tool_calls():
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris?",
},
Message(
content=None,
role="assistant",
tool_calls=[
ChatCompletionMessageToolCall(
function=Function(
arguments='{"location": "San Francisco, CA", "unit": "celsius"}',
name="get_current_weather",
),
id="call_G11shFcS024xEKjiAOSt6Tc9",
type="function",
),
ChatCompletionMessageToolCall(
function=Function(
arguments='{"location": "Tokyo, Japan", "unit": "celsius"}',
name="get_current_weather",
),
id="call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
type="function",
),
ChatCompletionMessageToolCall(
function=Function(
arguments='{"location": "Paris, France", "unit": "celsius"}',
name="get_current_weather",
),
id="call_nRjLXkWTJU2a4l9PZAf5as6g",
type="function",
),
],
function_call=None,
),
{
"tool_call_id": "call_G11shFcS024xEKjiAOSt6Tc9",
"role": "tool",
"name": "get_current_weather",
"content": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
},
{
"tool_call_id": "call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
"role": "tool",
"name": "get_current_weather",
"content": '{"location": "Tokyo", "temperature": "10", "unit": "celsius"}',
},
{
"tool_call_id": "call_nRjLXkWTJU2a4l9PZAf5as6g",
"role": "tool",
"name": "get_current_weather",
"content": '{"location": "Paris", "temperature": "22", "unit": "celsius"}',
},
]
result = trim_messages(messages=messages, max_tokens=1, return_response_tokens=True)
print(result)
assert len(result[0]) == 3 # final 3 messages are tool calls
def test_trimming_should_not_change_original_messages(): def test_trimming_should_not_change_original_messages():
messages = [ messages = [
{"role": "system", "content": "This is a short system message"}, {"role": "system", "content": "This is a short system message"},

View file

@ -312,7 +312,14 @@ class Message(OpenAIObject):
FunctionCall(**function_call) if function_call is not None else None FunctionCall(**function_call) if function_call is not None else None
), ),
"tool_calls": ( "tool_calls": (
[ChatCompletionMessageToolCall(**tool_call) for tool_call in tool_calls] [
(
ChatCompletionMessageToolCall(**tool_call)
if isinstance(tool_call, dict)
else tool_call
)
for tool_call in tool_calls
]
if tool_calls is not None if tool_calls is not None
else None else None
), ),

View file

@ -10658,7 +10658,7 @@ def get_token_count(messages, model):
return token_counter(model=model, messages=messages) return token_counter(model=model, messages=messages)
def shorten_message_to_fit_limit(message, tokens_needed, model): def shorten_message_to_fit_limit(message, tokens_needed, model: Optional[str]):
""" """
Shorten a message to fit within a token limit by removing characters from the middle. Shorten a message to fit within a token limit by removing characters from the middle.
""" """
@ -10666,7 +10666,7 @@ def shorten_message_to_fit_limit(message, tokens_needed, model):
# For OpenAI models, even blank messages cost 7 token, # For OpenAI models, even blank messages cost 7 token,
# and if the buffer is less than 3, the while loop will never end, # and if the buffer is less than 3, the while loop will never end,
# hence the value 10. # hence the value 10.
if "gpt" in model and tokens_needed <= 10: if model is not None and "gpt" in model and tokens_needed <= 10:
return message return message
content = message["content"] content = message["content"]
@ -10720,7 +10720,6 @@ def trim_messages(
# if users pass in max tokens, trim to this amount # if users pass in max tokens, trim to this amount
messages = copy.deepcopy(messages) messages = copy.deepcopy(messages)
try: try:
print_verbose(f"trimming messages")
if max_tokens is None: if max_tokens is None:
# Check if model is valid # Check if model is valid
if model in litellm.model_cost: if model in litellm.model_cost:
@ -10740,6 +10739,17 @@ def trim_messages(
system_message += "\n" if system_message else "" system_message += "\n" if system_message else ""
system_message += message["content"] system_message += message["content"]
## Handle Tool Call ## - check if last message is a tool response, return as is - https://github.com/BerriAI/litellm/issues/4931
tool_messages = []
for message in reversed(messages):
if message["role"] != "tool":
break
tool_messages.append(message)
# # Remove the collected tool messages from the original list
if len(tool_messages):
messages = messages[: -len(tool_messages)]
current_tokens = token_counter(model=model, messages=messages) current_tokens = token_counter(model=model, messages=messages)
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}") print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
@ -10771,6 +10781,9 @@ def trim_messages(
if system_message: if system_message:
final_messages = [system_message_event] + final_messages final_messages = [system_message_event] + final_messages
if len(tool_messages) > 0:
final_messages.extend(tool_messages)
if ( if (
return_response_tokens return_response_tokens
): # if user wants token count with new trimmed messages ): # if user wants token count with new trimmed messages
@ -10778,7 +10791,11 @@ def trim_messages(
return final_messages, response_tokens return final_messages, response_tokens
return final_messages return final_messages
except Exception as e: # [NON-Blocking, if error occurs just return final_messages except Exception as e: # [NON-Blocking, if error occurs just return final_messages
print_verbose(f"Got exception while token trimming{e}") verbose_logger.error(
"Got exception while token trimming - {}\n{}".format(
str(e), traceback.format_exc()
)
)
return messages return messages