add adapt_to_prompt_size to config

Krrish Dholakia 2023-09-21 15:07:39 -07:00
parent ff98af9b0c
commit be00dd6367
4 changed files with 62 additions and 11 deletions


@@ -2772,7 +2772,7 @@ def read_config_args(config_path) -> dict:
 ########## experimental completion variants ############################
-def completion_with_config(*args, config: Union[dict, str], **kwargs):
+def completion_with_config(*, config: Union[dict, str], **kwargs):
     if config is not None:
         if isinstance(config, str):
             config = read_config_args(config)
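
The signature change makes completion_with_config keyword-only: a positional call such as completion_with_config("gpt-3.5-turbo", messages, config=cfg) now raises a TypeError. A minimal sketch of the new calling convention; the exact config schema isn't visible in this hunk, so the dict below only carries keys the diff reads and is an assumption:

from litellm import completion_with_config

# Illustrative config shape only; see the keys read in the hunks below.
cfg = {
    "model": {},
    "adapt_to_prompt_size": False,
}
response = completion_with_config(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    config=cfg,
)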
@@ -2793,11 +2793,31 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
         raise Exception("No completion config in the config file")
     models_with_config = completion_config["model"].keys()
-    model = args[0] if len(args) > 0 else kwargs["model"]
-    messages = args[1] if len(args) > 1 else kwargs["messages"]
+    model = kwargs["model"]
+    messages = kwargs["messages"]
-    ## Default fallback models
-    fallback_models = completion_config.get("default_fallback_models")
+    ## completion config
+    fallback_models = completion_config.get("default_fallback_models", None)
+    available_models = completion_config.get("available_models", None)
+    adapt_to_prompt_size = completion_config.get("adapt_to_prompt_size", False)
     start_time = time.time()
+    if adapt_to_prompt_size:
+        ## Pick model based on token window
+        prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages))
+        try:
+            curr_max_tokens = litellm.get_max_tokens(model)["max_tokens"]
+        except:
+            curr_max_tokens = 2048
+        if curr_max_tokens < prompt_tokens:
+            for available_model in available_models:
+                try:
+                    curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"]
+                    if curr_max_tokens > prompt_tokens:
+                        model = available_model
+                except:
+                    continue
     end_time = time.time()
+    kwargs["model"] = model
     try:
         if model in models_with_config:
             ## Moderation check
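
With adapt_to_prompt_size set, the prompt is tokenized with the gpt-3.5-turbo tokenizer and compared against the current model's context window; on overflow, available_models is scanned for a model whose window fits. A sketch of a completion config carrying the three keys this hunk reads (model names are placeholders; anything beyond these keys is an assumption):

completion_config = {
    # per-model settings such as error_handling (see the next hunks)
    "model": {
        "gpt-3.5-turbo": {},
    },
    # popped front-to-back by the outer exception handler
    "default_fallback_models": ["gpt-3.5-turbo-16k"],
    # scanned in order when the prompt overflows the current model's window
    "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
    # enables the token-window check added in this hunk
    "adapt_to_prompt_size": True,
}

Note the scan has no break: the last entry in available_models whose window exceeds the prompt wins, and any model whose get_max_tokens lookup fails is skipped by the bare except.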
@@ -2814,7 +2834,7 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
             error_handling = completion_config["model"][model]["error_handling"]
             try:
-                response = litellm.completion(*args, **kwargs)
+                response = litellm.completion(**kwargs)
                 return response
             except Exception as e:
                 exception_name = type(e).__name__
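
exception_name feeds the per-model error_handling lookup, which the next hunk resolves via error_handler.get("fallback_model", None). A hedged sketch of that nesting, inferred from these lines (exception and model names are placeholders):

completion_config = {
    "model": {
        "gpt-3.5-turbo": {
            "error_handling": {
                # keyed by type(e).__name__ of the raised exception
                "ContextWindowExceededError": {
                    "fallback_model": "gpt-3.5-turbo-16k"
                }
            }
        }
    }
}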
@@ -2825,10 +2845,10 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
                 fallback_model = error_handler.get("fallback_model", None)
                 if fallback_model:
                     kwargs["model"] = fallback_model
-                    return litellm.completion(*args, **kwargs)
+                    return litellm.completion(**kwargs)
                 raise e
         else:
-            return litellm.completion(*args, **kwargs)
+            return litellm.completion(**kwargs)
     except Exception as e:
         if fallback_models:
             model = fallback_models.pop(0)
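
The outer handler consumes default_fallback_models front-to-back: each failure pops the head of the list and retries with it until the list is empty. The same pattern in isolation (a standalone sketch with a hypothetical helper, not the function's own code):

import litellm

def retry_with_fallbacks(fallback_models, **kwargs):
    # pop(0) preserves the configured priority order
    try:
        return litellm.completion(**kwargs)
    except Exception:
        if fallback_models:
            kwargs["model"] = fallback_models.pop(0)
            return retry_with_fallbacks(fallback_models, **kwargs)
        raise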
@@ -2933,7 +2953,7 @@ def completion_with_fallbacks(**kwargs):
             # delete model from kwargs if it exists
             if kwargs.get("model"):
                 del kwargs["model"]
             response = litellm.completion(**kwargs, model=model)
             if response != None:
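
completion_with_fallbacks deletes any "model" key from kwargs before calling litellm.completion(**kwargs, model=model); leaving it in place would raise TypeError: got multiple values for keyword argument 'model'. A usage sketch (the fallbacks kwarg name is an assumption matching litellm's completion_with_fallbacks at the time; model names are placeholders):

from litellm import completion_with_fallbacks

response = completion_with_fallbacks(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    fallbacks=["gpt-3.5-turbo-16k", "gpt-4"],
)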