Another small refactoring

This commit is contained in:
Duc Pham 2023-11-10 01:47:06 +07:00
parent 61f2e37349
commit c7ca8f75a2

View file

@@ -4687,7 +4687,7 @@ def trim_messages(
""" """
# Initialize max_tokens # Initialize max_tokens
# if users pass in max tokens, trim to this amount # if users pass in max tokens, trim to this amount
messages_copy = copy.deepcopy(messages) messages = copy.deepcopy(messages)
try: try:
print_verbose(f"trimming messages") print_verbose(f"trimming messages")
if max_tokens == None: if max_tokens == None:
@@ -4702,20 +4702,20 @@ def trim_messages(
return return
system_message = "" system_message = ""
for message in messages_copy: for message in messages:
if message["role"] == "system": if message["role"] == "system":
system_message += '\n' if system_message else '' system_message += '\n' if system_message else ''
system_message += message["content"] system_message += message["content"]
current_tokens = token_counter(model=model, messages=messages_copy) current_tokens = token_counter(model=model, messages=messages)
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}") print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
# Do nothing if current tokens under messages # Do nothing if current tokens under messages
if current_tokens < max_tokens: if current_tokens < max_tokens:
return messages_copy return messages
#### Trimming messages if current_tokens > max_tokens #### Trimming messages if current_tokens > max_tokens
print_verbose(f"Need to trim input messages: {messages_copy}, current_tokens{current_tokens}, max_tokens: {max_tokens}") print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
if system_message: if system_message:
system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model) system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
@@ -4724,9 +4724,9 @@ def trim_messages(
# Since all system messages are combined and trimmed to fit the max_tokens, # Since all system messages are combined and trimmed to fit the max_tokens,
# we remove all system messages from the messages list # we remove all system messages from the messages list
messages_copy = [message for message in messages_copy if message["role"] != "system"] messages = [message for message in messages if message["role"] != "system"]
final_messages = process_messages(messages=messages_copy, max_tokens=max_tokens, model=model) final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)
# Add system message to the beginning of the final messages # Add system message to the beginning of the final messages
if system_message: if system_message:
@@ -4738,7 +4738,7 @@ def trim_messages(
return final_messages return final_messages
except Exception as e: # [NON-Blocking, if error occurs just return final_messages except Exception as e: # [NON-Blocking, if error occurs just return final_messages
print_verbose(f"Got exception while token trimming{e}") print_verbose(f"Got exception while token trimming{e}")
return messages_copy return messages
def get_valid_models(): def get_valid_models():
""" """