feat(factory.py): option to add function details to prompt, if model doesn't support functions param

This commit is contained in:
Krrish Dholakia 2023-10-09 09:53:31 -07:00
parent f6f7c0b891
commit 704be9dcd1
8 changed files with 130 additions and 27 deletions

View file

@ -1001,17 +1001,20 @@ def get_optional_params( # use the openai defaults
}
# filter out those parameters that were passed with non-default values
non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])}
optional_params = {}
## raise exception if function calling passed in for a provider that doesn't support it
if "functions" in non_default_params or "function_call" in non_default_params:
if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure":
raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider")
if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
else:
raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider")
def _check_valid_arg(supported_params):
print_verbose(f"checking params for {model}")
print_verbose(f"params passed in {passed_params}")
print_verbose(f"non-default params passed in {non_default_params}")
unsupported_params = []
unsupported_params = {}
for k in non_default_params.keys():
if k not in supported_params:
if k == "n" and n == 1: # langchain sends n=1 as a default value
@ -1020,12 +1023,11 @@ def get_optional_params( # use the openai defaults
elif k == "request_timeout": # litellm handles request time outs
pass
else:
unsupported_params.append(k)
unsupported_params[k] = non_default_params[k]
if unsupported_params and not litellm.drop_params:
raise ValueError("LiteLLM.Exception: Unsupported parameters passed: {}".format(', '.join(unsupported_params)))
## raise exception if provider doesn't support passed in param
optional_params = {}
if custom_llm_provider == "anthropic":
## check if unsupported param passed in
supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"]