add token trimming to adapt to prompt flag in config

Krrish Dholakia 2023-09-21 17:47:52 -07:00
parent 38a4a2c376
commit cf5060f136
4 changed files with 24 additions and 6 deletions

View file

@@ -88,4 +88,4 @@ def test_config_context_adapt_to_prompt():
         print(f"Exception: {e}")
         pytest.fail(f"An exception occurred: {e}")
-# test_config_context_adapt_to_prompt()
+test_config_context_adapt_to_prompt()

View file

@@ -2802,7 +2802,13 @@ def completion_with_config(config: Union[dict, str], **kwargs):
     fallback_models = config.get("default_fallback_models", None)
     available_models = config.get("available_models", None)
     adapt_to_prompt_size = config.get("adapt_to_prompt_size", False)
-    start_time = time.time()
+    trim_messages_flag = config.get("trim_messages", False)
+    prompt_larger_than_model = False
+    max_model = model
+    try:
+        max_tokens = litellm.get_max_tokens(model)["max_tokens"]
+    except:
+        max_tokens = 2048 # assume curr model's max window is 2048 tokens
     if adapt_to_prompt_size:
         ## Pick model based on token window
         prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages))
@@ -2811,14 +2817,22 @@ def completion_with_config(config: Union[dict, str], **kwargs):
         except:
             curr_max_tokens = 2048
         if curr_max_tokens < prompt_tokens:
+            prompt_larger_than_model = True
             for available_model in available_models:
                 try:
                     curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"]
+                    if curr_max_tokens > max_tokens:
+                        max_tokens = curr_max_tokens
+                        max_model = available_model
                     if curr_max_tokens > prompt_tokens:
                         model = available_model
+                        prompt_larger_than_model = False
                 except:
                     continue
-    end_time = time.time()
+    if prompt_larger_than_model:
+        messages = trim_messages(messages=messages, model=max_model)
+        kwargs["messages"] = messages
     kwargs["model"] = model
     try:
         if model in models_with_config:
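Taken together, the two hunks above change completion_with_config so that, when adapt_to_prompt_size is enabled, it tracks the largest context window among the available models and, if even that window is smaller than the prompt, trims the messages with trim_messages before dispatching the call (the config also gains a trim_messages flag). A minimal usage sketch; the model names and the long prompt are illustrative assumptions, not taken from the repo:

    import litellm

    config = {
        "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],  # assumed model list
        "adapt_to_prompt_size": True,  # pick a model whose context window fits the prompt
        "trim_messages": True,         # new flag read by this commit
    }
    messages = [{"role": "user", "content": "a very long prompt ..."}]

    # If no available model has a large enough window, the messages are trimmed
    # to fit the largest one before the request is sent.
    response = litellm.completion_with_config(config=config, model="gpt-3.5-turbo", messages=messages)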
@@ -3052,8 +3066,7 @@ def shorten_message_to_fit_limit(
 # Credits for this code go to Killian Lucas
 def trim_messages(
     messages,
-    model = None,
-    system_message = None, # str of user system message
+    model: Optional[str] = None,
     trim_ratio: float = 0.75,
     return_response_tokens: bool = False,
     max_tokens = None
@@ -3086,6 +3099,11 @@ def trim_messages(
         # do nothing, just return messages
         return
+    system_message = ""
+    for message in messages:
+        if message["role"] == "system":
+            system_message += message["content"]
     current_tokens = token_counter(model=model, messages=messages)
     # Do nothing if current tokens under messages
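The last two hunks drop the separate system_message parameter from trim_messages and instead collect the system prompt from the messages list itself. A short sketch of calling the helper directly, assuming it is importable at the package level as in later releases; the messages are illustrative:

    import litellm

    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "a prompt long enough to overflow the context window ..."},
    ]
    # Trims to roughly trim_ratio (default 0.75) of the model's window while keeping
    # the system message gathered by the new loop.
    trimmed = litellm.trim_messages(messages, model="gpt-3.5-turbo")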

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.731"
+version = "0.1.732"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"