fix(utils.py): fix trim_messages to handle tool calling

Fixes https://github.com/BerriAI/litellm/issues/4931
This commit is contained in:
Krrish Dholakia 2024-07-29 13:04:41 -07:00
parent dd2d61bfce
commit ae4bcd8a41
4 changed files with 100 additions and 11 deletions

View file

@ -173,6 +173,71 @@ def test_trimming_with_system_message_exceeding_max_tokens():
assert len(trimmed_messages) == 1
def test_trimming_with_tool_calls():
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris?",
},
Message(
content=None,
role="assistant",
tool_calls=[
ChatCompletionMessageToolCall(
function=Function(
arguments='{"location": "San Francisco, CA", "unit": "celsius"}',
name="get_current_weather",
),
id="call_G11shFcS024xEKjiAOSt6Tc9",
type="function",
),
ChatCompletionMessageToolCall(
function=Function(
arguments='{"location": "Tokyo, Japan", "unit": "celsius"}',
name="get_current_weather",
),
id="call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
type="function",
),
ChatCompletionMessageToolCall(
function=Function(
arguments='{"location": "Paris, France", "unit": "celsius"}',
name="get_current_weather",
),
id="call_nRjLXkWTJU2a4l9PZAf5as6g",
type="function",
),
],
function_call=None,
),
{
"tool_call_id": "call_G11shFcS024xEKjiAOSt6Tc9",
"role": "tool",
"name": "get_current_weather",
"content": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
},
{
"tool_call_id": "call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
"role": "tool",
"name": "get_current_weather",
"content": '{"location": "Tokyo", "temperature": "10", "unit": "celsius"}',
},
{
"tool_call_id": "call_nRjLXkWTJU2a4l9PZAf5as6g",
"role": "tool",
"name": "get_current_weather",
"content": '{"location": "Paris", "temperature": "22", "unit": "celsius"}',
},
]
result = trim_messages(messages=messages, max_tokens=1, return_response_tokens=True)
print(result)
assert len(result[0]) == 3 # final 3 messages are tool calls
def test_trimming_should_not_change_original_messages():
messages = [
{"role": "system", "content": "This is a short system message"},