mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(factory.py): option to add function details to prompt, if model doesn't support functions param
This commit is contained in:
parent
f6f7c0b891
commit
704be9dcd1
8 changed files with 130 additions and 27 deletions
|
@ -41,6 +41,7 @@ model_alias_map: Dict[str, str] = {}
|
||||||
max_budget: float = 0.0 # set the max budget across all providers
|
max_budget: float = 0.0 # set the max budget across all providers
|
||||||
_current_cost = 0 # private variable, used if max budget is set
|
_current_cost = 0 # private variable, used if max budget is set
|
||||||
error_logs: Dict = {}
|
error_logs: Dict = {}
|
||||||
|
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||||
#############################################
|
#############################################
|
||||||
|
|
||||||
def get_model_cost_map():
|
def get_model_cost_map():
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -200,6 +200,24 @@ def hf_chat_template(model: str, messages: list):
|
||||||
except:
|
except:
|
||||||
raise Exception("Error rendering template")
|
raise Exception("Error rendering template")
|
||||||
|
|
||||||
|
# Function call template
|
||||||
|
def function_call_prompt(messages: list, functions: list):
|
||||||
|
function_prompt = "The following functions are available to you:"
|
||||||
|
for function in functions:
|
||||||
|
function_prompt += f"""\n{function}\n"""
|
||||||
|
|
||||||
|
function_added_to_prompt = False
|
||||||
|
for message in messages:
|
||||||
|
if "system" in message["role"]:
|
||||||
|
message['content'] += f"""{function_prompt}"""
|
||||||
|
function_added_to_prompt = True
|
||||||
|
|
||||||
|
if function_added_to_prompt == False:
|
||||||
|
messages.append({'role': 'system', 'content': f"""{function_prompt}"""})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
# Custom prompt template
|
# Custom prompt template
|
||||||
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
|
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
|
||||||
prompt = initial_prompt_value
|
prompt = initial_prompt_value
|
||||||
|
@ -243,4 +261,5 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
|
||||||
else:
|
else:
|
||||||
return hf_chat_template(original_model_name, messages)
|
return hf_chat_template(original_model_name, messages)
|
||||||
except:
|
except:
|
||||||
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
|
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
|
||||||
|
|
|
@ -28,7 +28,7 @@ from litellm.utils import (
|
||||||
completion_with_fallbacks,
|
completion_with_fallbacks,
|
||||||
get_llm_provider,
|
get_llm_provider,
|
||||||
get_api_key,
|
get_api_key,
|
||||||
mock_completion_streaming_obj,
|
mock_completion_streaming_obj
|
||||||
)
|
)
|
||||||
from .llms import (
|
from .llms import (
|
||||||
anthropic,
|
anthropic,
|
||||||
|
@ -48,7 +48,7 @@ from .llms import (
|
||||||
oobabooga,
|
oobabooga,
|
||||||
palm,
|
palm,
|
||||||
vertex_ai)
|
vertex_ai)
|
||||||
from .llms.prompt_templates.factory import prompt_factory, custom_prompt
|
from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Callable, List, Optional, Dict
|
from typing import Callable, List, Optional, Dict
|
||||||
|
@ -259,27 +259,33 @@ def completion(
|
||||||
api_base = "https://proxy.litellm.ai"
|
api_base = "https://proxy.litellm.ai"
|
||||||
custom_llm_provider = "openai"
|
custom_llm_provider = "openai"
|
||||||
api_key = model_api_key
|
api_key = model_api_key
|
||||||
|
|
||||||
# check if user passed in any of the OpenAI optional params
|
# check if user passed in any of the OpenAI optional params
|
||||||
optional_params = get_optional_params(
|
optional_params = get_optional_params(
|
||||||
functions=functions,
|
functions=functions,
|
||||||
function_call=function_call,
|
function_call=function_call,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
n=n,
|
n=n,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
user=user,
|
user=user,
|
||||||
request_timeout=request_timeout,
|
request_timeout=request_timeout,
|
||||||
deployment_id=deployment_id,
|
deployment_id=deployment_id,
|
||||||
# params to identify the model
|
# params to identify the model
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
**non_default_params
|
**non_default_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if litellm.add_function_to_prompt and optional_params.get("functions_unsupported_model", None): # if user opts to add it to prompt, when API doesn't support function calling
|
||||||
|
functions_unsupported_model = optional_params.pop("functions_unsupported_model")
|
||||||
|
messages = function_call_prompt(messages=messages, functions=functions_unsupported_model)
|
||||||
|
|
||||||
# For logging - save the values of the litellm-specific params passed in
|
# For logging - save the values of the litellm-specific params passed in
|
||||||
litellm_params = get_litellm_params(
|
litellm_params = get_litellm_params(
|
||||||
return_async=return_async,
|
return_async=return_async,
|
||||||
|
|
75
litellm/tests/test_add_function_to_prompt.py
Normal file
75
litellm/tests/test_add_function_to_prompt.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# Allow the user to map the function to the prompt, if the model doesn't support function calling
|
||||||
|
|
||||||
|
import sys, os, pytest
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
## case 1: set_function_to_prompt not set
|
||||||
|
def test_function_call_non_openai_model():
|
||||||
|
try:
|
||||||
|
model = "claude-instant-1"
|
||||||
|
messages=[{"role": "user", "content": "what's the weather in sf?"}]
|
||||||
|
functions = [
|
||||||
|
{
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = litellm.completion(model=model, messages=messages, functions=functions)
|
||||||
|
pytest.fail(f'An error occurred')
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
test_function_call_non_openai_model()
|
||||||
|
|
||||||
|
## case 2: add_function_to_prompt set
|
||||||
|
def test_function_call_non_openai_model_litellm_mod_set():
|
||||||
|
litellm.add_function_to_prompt = True
|
||||||
|
try:
|
||||||
|
model = "claude-instant-1"
|
||||||
|
messages=[{"role": "user", "content": "what's the weather in sf?"}]
|
||||||
|
functions = [
|
||||||
|
{
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = litellm.completion(model=model, messages=messages, functions=functions)
|
||||||
|
print(f'response: {response}')
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f'An error occurred {e}')
|
||||||
|
|
||||||
|
# test_function_call_non_openai_model_litellm_mod_set()
|
|
@ -1001,17 +1001,20 @@ def get_optional_params( # use the openai defaults
|
||||||
}
|
}
|
||||||
# filter out those parameters that were passed with non-default values
|
# filter out those parameters that were passed with non-default values
|
||||||
non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])}
|
non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])}
|
||||||
|
optional_params = {}
|
||||||
## raise exception if function calling passed in for a provider that doesn't support it
|
## raise exception if function calling passed in for a provider that doesn't support it
|
||||||
if "functions" in non_default_params or "function_call" in non_default_params:
|
if "functions" in non_default_params or "function_call" in non_default_params:
|
||||||
if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure":
|
if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure":
|
||||||
raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider")
|
if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
|
||||||
|
optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
|
||||||
|
else:
|
||||||
|
raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider")
|
||||||
|
|
||||||
def _check_valid_arg(supported_params):
|
def _check_valid_arg(supported_params):
|
||||||
print_verbose(f"checking params for {model}")
|
print_verbose(f"checking params for {model}")
|
||||||
print_verbose(f"params passed in {passed_params}")
|
print_verbose(f"params passed in {passed_params}")
|
||||||
print_verbose(f"non-default params passed in {non_default_params}")
|
print_verbose(f"non-default params passed in {non_default_params}")
|
||||||
unsupported_params = []
|
unsupported_params = {}
|
||||||
for k in non_default_params.keys():
|
for k in non_default_params.keys():
|
||||||
if k not in supported_params:
|
if k not in supported_params:
|
||||||
if k == "n" and n == 1: # langchain sends n=1 as a default value
|
if k == "n" and n == 1: # langchain sends n=1 as a default value
|
||||||
|
@ -1020,12 +1023,11 @@ def get_optional_params( # use the openai defaults
|
||||||
elif k == "request_timeout": # litellm handles request time outs
|
elif k == "request_timeout": # litellm handles request time outs
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
unsupported_params.append(k)
|
unsupported_params[k] = non_default_params[k]
|
||||||
if unsupported_params and not litellm.drop_params:
|
if unsupported_params and not litellm.drop_params:
|
||||||
raise ValueError("LiteLLM.Exception: Unsupported parameters passed: {}".format(', '.join(unsupported_params)))
|
raise ValueError("LiteLLM.Exception: Unsupported parameters passed: {}".format(', '.join(unsupported_params)))
|
||||||
|
|
||||||
## raise exception if provider doesn't support passed in param
|
## raise exception if provider doesn't support passed in param
|
||||||
optional_params = {}
|
|
||||||
if custom_llm_provider == "anthropic":
|
if custom_llm_provider == "anthropic":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"]
|
supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue