add token trimming to adapt to prompt flag in config

Krrish Dholakia 2023-09-21 17:47:52 -07:00
parent 38a4a2c376
commit cf5060f136
4 changed files with 24 additions and 6 deletions

View file

@@ -88,4 +88,4 @@ def test_config_context_adapt_to_prompt():
        print(f"Exception: {e}")
        pytest.fail(f"An exception occurred: {e}")
# test_config_context_adapt_to_prompt()
test_config_context_adapt_to_prompt()

View file

@@ -2802,7 +2802,13 @@ def completion_with_config(config: Union[dict, str], **kwargs):
    fallback_models = config.get("default_fallback_models", None)
    available_models = config.get("available_models", None)
    adapt_to_prompt_size = config.get("adapt_to_prompt_size", False)
    start_time = time.time()
    trim_messages_flag = config.get("trim_messages", False)
    prompt_larger_than_model = False
    max_model = model
    try:
        max_tokens = litellm.get_max_tokens(model)["max_tokens"]
    except:
        max_tokens = 2048 # assume curr model's max window is 2048 tokens
    if adapt_to_prompt_size:
        ## Pick model based on token window
        prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages))
@@ -2811,14 +2817,22 @@ def completion_with_config(config: Union[dict, str], **kwargs):
        except:
            curr_max_tokens = 2048
        if curr_max_tokens < prompt_tokens:
            prompt_larger_than_model = True
            for available_model in available_models:
                try:
                    curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"]
                    if curr_max_tokens > max_tokens:
                        max_tokens = curr_max_tokens
                        max_model = available_model
                    if curr_max_tokens > prompt_tokens:
                        model = available_model
                        prompt_larger_than_model = False
                except:
                    continue
        end_time = time.time()
    if prompt_larger_than_model:
        messages = trim_messages(messages=messages, model=max_model)
        kwargs["messages"] = messages
    kwargs["model"] = model
    try:
        if model in models_with_config:
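
For reference, a minimal sketch (not part of this commit) of how the new flags might be exercised from calling code. The config keys are the ones read in the hunk above; the model names, the prompt, and the import path are assumptions, as is passing model/messages through **kwargs.

# Sketch only: config keys come from the diff above; model names and the
# import path (litellm.utils is assumed; file names are not shown in this view)
# are not part of this commit.
from litellm.utils import completion_with_config

config = {
    "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],  # candidates for adapt_to_prompt_size
    "adapt_to_prompt_size": True,  # pick a model whose context window fits the prompt
    "trim_messages": True,         # flag read above; the visible trim path is gated on prompt_larger_than_model
}

long_prompt = "hello " * 5000  # large enough to overflow a ~4k-token window
messages = [{"role": "user", "content": long_prompt}]

# If no available model fits the prompt, the new code calls
# trim_messages(messages=messages, model=max_model) before dispatching.
response = completion_with_config(config=config, model="gpt-3.5-turbo", messages=messages)
print(response)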
@@ -3052,8 +3066,7 @@ def shorten_message_to_fit_limit(
# Credits for this code go to Killian Lucas
def trim_messages(
    messages,
    model = None,
    system_message = None, # str of user system message
    model: Optional[str] = None,
    trim_ratio: float = 0.75,
    return_response_tokens: bool = False,
    max_tokens = None
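
With system_message removed from the signature (system prompts are now collected from the messages themselves, as the next hunk shows), a caller passes the messages plus an optional model or an explicit max_tokens. A minimal usage sketch, assuming the helper is importable from litellm.utils and using an invented conversation:

# Sketch only: import path, model name, and sample messages are assumptions.
from litellm.utils import trim_messages

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Summarize: " + "lorem ipsum " * 4000},
]

# Trim against the model's context window, scaled by trim_ratio (default 0.75) ...
trimmed = trim_messages(messages, model="gpt-3.5-turbo")

# ... or against an explicit token budget for the same model.
trimmed_small = trim_messages(messages, model="gpt-3.5-turbo", max_tokens=1000)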
@@ -3086,6 +3099,11 @@ def trim_messages(
            # do nothing, just return messages
            return
    system_message = ""
    for message in messages:
        if message["role"] == "system":
            system_message += message["content"]
    current_tokens = token_counter(model=model, messages=messages)
    # Do nothing if current tokens under messages
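
The loop added above folds every "system" entry into a single string before the conversation's tokens are counted, presumably so later shortening can keep the system prompt intact while other turns are trimmed. A standalone sketch of just that accumulation step, with invented sample messages:

# Mirrors the added lines: concatenate all system contents into one string.
messages = [
    {"role": "system", "content": "You are concise. "},
    {"role": "system", "content": "Answer in English."},
    {"role": "user", "content": "Explain token trimming."},
]

system_message = ""
for message in messages:
    if message["role"] == "system":
        system_message += message["content"]

print(system_message)  # -> "You are concise. Answer in English."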

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.731"
version = "0.1.732"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"