mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(utils.py): fix trim_messages to handle tool calling
Fixes https://github.com/BerriAI/litellm/issues/4931
This commit is contained in:
parent
dd2d61bfce
commit
ae4bcd8a41
4 changed files with 100 additions and 11 deletions
|
@ -173,6 +173,71 @@ def test_trimming_with_system_message_exceeding_max_tokens():
|
|||
assert len(trimmed_messages) == 1
|
||||
|
||||
|
||||
def test_trimming_with_tool_calls():
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in San Francisco, Tokyo, and Paris?",
|
||||
},
|
||||
Message(
|
||||
content=None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
arguments='{"location": "San Francisco, CA", "unit": "celsius"}',
|
||||
name="get_current_weather",
|
||||
),
|
||||
id="call_G11shFcS024xEKjiAOSt6Tc9",
|
||||
type="function",
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
arguments='{"location": "Tokyo, Japan", "unit": "celsius"}',
|
||||
name="get_current_weather",
|
||||
),
|
||||
id="call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
|
||||
type="function",
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
arguments='{"location": "Paris, France", "unit": "celsius"}',
|
||||
name="get_current_weather",
|
||||
),
|
||||
id="call_nRjLXkWTJU2a4l9PZAf5as6g",
|
||||
type="function",
|
||||
),
|
||||
],
|
||||
function_call=None,
|
||||
),
|
||||
{
|
||||
"tool_call_id": "call_G11shFcS024xEKjiAOSt6Tc9",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}',
|
||||
},
|
||||
{
|
||||
"tool_call_id": "call_e0ss43Bg7H8Z9KGdMGWyZ9Mj",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": '{"location": "Tokyo", "temperature": "10", "unit": "celsius"}',
|
||||
},
|
||||
{
|
||||
"tool_call_id": "call_nRjLXkWTJU2a4l9PZAf5as6g",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": '{"location": "Paris", "temperature": "22", "unit": "celsius"}',
|
||||
},
|
||||
]
|
||||
result = trim_messages(messages=messages, max_tokens=1, return_response_tokens=True)
|
||||
|
||||
print(result)
|
||||
|
||||
assert len(result[0]) == 3 # final 3 messages are tool calls
|
||||
|
||||
|
||||
def test_trimming_should_not_change_original_messages():
|
||||
messages = [
|
||||
{"role": "system", "content": "This is a short system message"},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue