forked from phoenix/litellm-mirror
add token trimming to adapt to prompt flag in config
parent 38a4a2c376
commit cf5060f136
4 changed files with 24 additions and 6 deletions
Binary file not shown.
@@ -88,4 +88,4 @@ def test_config_context_adapt_to_prompt():
         print(f"Exception: {e}")
         pytest.fail(f"An exception occurred: {e}")
 
-# test_config_context_adapt_to_prompt()
+test_config_context_adapt_to_prompt()
@@ -2802,7 +2802,13 @@ def completion_with_config(config: Union[dict, str], **kwargs):
     fallback_models = config.get("default_fallback_models", None)
     available_models = config.get("available_models", None)
     adapt_to_prompt_size = config.get("adapt_to_prompt_size", False)
-    start_time = time.time()
+    trim_messages_flag = config.get("trim_messages", False)
+    prompt_larger_than_model = False
+    max_model = model
+    try:
+        max_tokens = litellm.get_max_tokens(model)["max_tokens"]
+    except:
+        max_tokens = 2048 # assume curr model's max window is 2048 tokens
     if adapt_to_prompt_size:
         ## Pick model based on token window
         prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages))
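For orientation, the config.get calls in this hunk imply a config dict shaped roughly as follows. This is only a sketch: the key names come from the hunk, while the model names and values are illustrative placeholders, not part of the commit.

# Hypothetical config for completion_with_config; keys mirror the
# config.get(...) lookups in the hunk above, values are made up for illustration.
config = {
    "default_fallback_models": ["gpt-3.5-turbo"],
    "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
    "adapt_to_prompt_size": True,  # pick a model whose context window fits the prompt
    "trim_messages": True,         # flag read by the new trim_messages_flag line
}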
@@ -2811,14 +2817,22 @@ def completion_with_config(config: Union[dict, str], **kwargs):
         except:
             curr_max_tokens = 2048
         if curr_max_tokens < prompt_tokens:
+            prompt_larger_than_model = True
             for available_model in available_models:
                 try:
                     curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"]
+                    if curr_max_tokens > max_tokens:
+                        max_tokens = curr_max_tokens
+                        max_model = available_model
                     if curr_max_tokens > prompt_tokens:
                         model = available_model
+                        prompt_larger_than_model = False
                 except:
                     continue
-        end_time = time.time()
+        if prompt_larger_than_model:
+            messages = trim_messages(messages=messages, model=max_model)
+            kwargs["messages"] = messages
+
     kwargs["model"] = model
     try:
         if model in models_with_config:
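Putting the two hunks together: when adapt_to_prompt_size is set and even the largest model in available_models cannot fit the prompt, the new code trims the messages to the largest-window model before dispatching. A minimal usage sketch, assuming the config dict sketched above and that completion_with_config is importable from the litellm package; the messages themselves are placeholders.

from litellm import completion_with_config

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this very long document ..."},
]

# model/messages are passed through **kwargs; if the prompt exceeds every
# window in available_models, the messages are trimmed to fit max_model.
response = completion_with_config(
    config=config,
    model="gpt-3.5-turbo",
    messages=messages,
)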
@@ -3052,8 +3066,7 @@ def shorten_message_to_fit_limit(
 # Credits for this code go to Killian Lucas
 def trim_messages(
     messages,
-    model = None,
-    system_message = None, # str of user system message
+    model: Optional[str] = None,
     trim_ratio: float = 0.75,
     return_response_tokens: bool = False,
     max_tokens = None
@@ -3086,6 +3099,11 @@ def trim_messages(
             # do nothing, just return messages
             return
 
+        system_message = ""
+        for message in messages:
+            if message["role"] == "system":
+                system_message += message["content"]
+
         current_tokens = token_counter(model=model, messages=messages)
 
         # Do nothing if current tokens under messages
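The reworked trim_messages can also be used on its own. A small sketch against the new signature shown above, assuming the function is exported at the package level; the oversized prompt is contrived for illustration.

from litellm import trim_messages

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "lorem ipsum " * 4000},  # deliberately oversized
]

# Trims to roughly trim_ratio (default 0.75) of the model's context window.
trimmed = trim_messages(messages, model="gpt-3.5-turbo")

# With return_response_tokens=True the remaining token budget for the
# completion is returned as well (per the parameter in the signature above).
trimmed, response_tokens = trim_messages(
    messages, model="gpt-3.5-turbo", return_response_tokens=True
)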
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.731"
+version = "0.1.732"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"