Mirror of https://github.com/BerriAI/litellm.git
fix(utils.py): fix trim_messages to handle tool calling
Fixes https://github.com/BerriAI/litellm/issues/4931
parent dd2d61bfce
commit ae4bcd8a41
4 changed files with 100 additions and 11 deletions
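For context, a minimal sketch of the scenario this commit fixes. Before the change, trim_messages could trim away a trailing role "tool" message, detaching it from the assistant tool_calls entry it answers and producing an invalid conversation. The message shapes below are illustrative; the call assumes the public trim_messages helper in litellm.utils with its model and max_tokens parameters.

from litellm.utils import trim_messages

# A conversation that ends with a tool response. The fix sets this
# trailing "tool" message aside during trimming and re-appends it,
# so it stays paired with the assistant "tool_calls" entry.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather in Boston?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_abc123",  # illustrative id
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Boston"}'},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_abc123", "content": '{"temp_f": 72}'},
]

trimmed = trim_messages(messages, model="gpt-3.5-turbo", max_tokens=2000)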
@@ -10658,7 +10658,7 @@ def get_token_count(messages, model):
     return token_counter(model=model, messages=messages)


-def shorten_message_to_fit_limit(message, tokens_needed, model):
+def shorten_message_to_fit_limit(message, tokens_needed, model: Optional[str]):
     """
     Shorten a message to fit within a token limit by removing characters from the middle.
     """
@@ -10666,7 +10666,7 @@ def shorten_message_to_fit_limit(message, tokens_needed, model):
     # For OpenAI models, even blank messages cost 7 token,
     # and if the buffer is less than 3, the while loop will never end,
     # hence the value 10.
-    if "gpt" in model and tokens_needed <= 10:
+    if model is not None and "gpt" in model and tokens_needed <= 10:
         return message

     content = message["content"]
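This hunk pairs the new Optional[str] annotation with a None guard: `"gpt" in model` raises TypeError when model is None, so the membership test must be short-circuited. A standalone sketch of just that guard (the helper name is hypothetical, not part of litellm):

from typing import Optional

def needs_openai_token_floor(model: Optional[str], tokens_needed: int) -> bool:
    # `and` short-circuits left to right, so `"gpt" in model` is only
    # evaluated once we know model is a string, never None.
    return model is not None and "gpt" in model and tokens_needed <= 10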
@@ -10720,7 +10720,6 @@ def trim_messages(
     # if users pass in max tokens, trim to this amount
     messages = copy.deepcopy(messages)
     try:
-        print_verbose(f"trimming messages")
         if max_tokens is None:
             # Check if model is valid
             if model in litellm.model_cost:
@@ -10740,6 +10739,17 @@ def trim_messages(
                 system_message += "\n" if system_message else ""
                 system_message += message["content"]

+        ## Handle Tool Call ## - check if last message is a tool response, return as is - https://github.com/BerriAI/litellm/issues/4931
+        tool_messages = []
+
+        for message in reversed(messages):
+            if message["role"] != "tool":
+                break
+            tool_messages.append(message)
+        # # Remove the collected tool messages from the original list
+        if len(tool_messages):
+            messages = messages[: -len(tool_messages)]
+
         current_tokens = token_counter(model=model, messages=messages)
         print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")

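The heart of the fix is this reverse scan: walk the list from the end, collect consecutive role "tool" messages, stop at the first message with any other role, then slice the collected messages off so they are exempt from trimming. The same pattern as a self-contained sketch (the helper name and signature are mine, not litellm's):

def split_trailing_tool_messages(messages: list) -> tuple:
    # Collect consecutive trailing "tool" messages, newest first,
    # exactly as the loop in the hunk above does.
    tool_messages = []
    for message in reversed(messages):
        if message["role"] != "tool":
            break
        tool_messages.append(message)
    # Slice them off the working list so trimming never touches them.
    if tool_messages:
        messages = messages[: -len(tool_messages)]
    return messages, tool_messages

Note that because the scan iterates newest-first, tool_messages comes out in reverse chronological order; with a single trailing tool response (the common case) the later re-append is order-preserving.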
@@ -10771,6 +10781,9 @@ def trim_messages(
         if system_message:
             final_messages = [system_message_event] + final_messages

+        if len(tool_messages) > 0:
+            final_messages.extend(tool_messages)
+
         if (
             return_response_tokens
         ):  # if user wants token count with new trimmed messages
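A quick happy-path check of this hunk's effect, reusing the messages list from the first sketch (illustrative only; on the exception fallback below, the untrimmed input is returned instead):

trimmed = trim_messages(messages, model="gpt-3.5-turbo", max_tokens=50)
# The set-aside tool response is appended back after trimming, so the
# conversation still ends with the tool message.
assert trimmed[-1]["role"] == "tool"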
@@ -10778,7 +10791,11 @@ def trim_messages(
             return final_messages, response_tokens
         return final_messages
     except Exception as e:  # [NON-Blocking, if error occurs just return final_messages
-        print_verbose(f"Got exception while token trimming{e}")
+        verbose_logger.error(
+            "Got exception while token trimming - {}\n{}".format(
+                str(e), traceback.format_exc()
+            )
+        )
         return messages
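The final hunk swaps a bare print_verbose for structured error logging while keeping the non-blocking contract: any failure during trimming is logged with its traceback and the original messages are returned unchanged. The same pattern in isolation, as a generic sketch with stdlib logging (the names are mine, not litellm's):

import logging
import traceback

logger = logging.getLogger(__name__)

def safe_trim(messages: list, trim_fn) -> list:
    # Non-blocking: log the failure with its traceback, then fall
    # back to returning the untrimmed input.
    try:
        return trim_fn(messages)
    except Exception as e:
        logger.error(
            "Got exception while token trimming - {}\n{}".format(
                str(e), traceback.format_exc()
            )
        )
        return messages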