Another small refactoring

This commit is contained in:
Duc Pham 2023-11-10 01:47:06 +07:00
parent eeac3954d5
commit 8e13da198c

View file

@ -4687,7 +4687,7 @@ def trim_messages(
"""
# Initialize max_tokens
# if users pass in max tokens, trim to this amount
messages_copy = copy.deepcopy(messages)
messages = copy.deepcopy(messages)
try:
print_verbose(f"trimming messages")
if max_tokens == None:
@ -4702,20 +4702,20 @@ def trim_messages(
return
system_message = ""
for message in messages_copy:
for message in messages:
if message["role"] == "system":
system_message += '\n' if system_message else ''
system_message += message["content"]
current_tokens = token_counter(model=model, messages=messages_copy)
current_tokens = token_counter(model=model, messages=messages)
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
# Do nothing if current tokens under messages
if current_tokens < max_tokens:
return messages_copy
return messages
#### Trimming messages if current_tokens > max_tokens
print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
if system_message:
system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
@ -4724,9 +4724,9 @@ def trim_messages(
# Since all system messages are combined and trimmed to fit the max_tokens,
# we remove all system messages from the messages list
messages_copy = [message for message in messages_copy if message["role"] != "system"]
messages = [message for message in messages if message["role"] != "system"]
final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model)
final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
# Add system message to the beginning of the final messages
if system_message:
@ -4738,7 +4738,7 @@ def trim_messages(
return final_messages
except Exception as e: # [NON-Blocking, if error occurs just return final_messages
print_verbose(f"Got exception while token trimming{e}")
return messages_copy
return messages
def get_valid_models():
"""