diff --git a/litellm/__init__.py b/litellm/__init__.py index 84520af954..dcdf2f96a4 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -41,6 +41,7 @@ model_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers _current_cost = 0 # private variable, used if max budget is set error_logs: Dict = {} +add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt ############################################# def get_model_cost_map(): diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc index 56dc8017c9..32ca510412 100644 Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc index d6d03bc5ad..6f74501e53 100644 Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc index df6c1ee8b2..6f9ecddb40 100644 Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 091f8aa958..e26a846fb9 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -200,6 +200,24 @@ def hf_chat_template(model: str, messages: list): except: raise Exception("Error rendering template") +# Function call template +def function_call_prompt(messages: list, functions: list): + function_prompt = "The following functions are available to you:" + for function in functions: + function_prompt += f"""\n{function}\n""" + + function_added_to_prompt = False + for message in messages: + if "system" in message["role"]: + message['content'] += f"""{function_prompt}""" + function_added_to_prompt = True + + if function_added_to_prompt == False: + messages.append({'role': 'system', 'content': f"""{function_prompt}"""}) + + return messages + + # Custom prompt template def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""): prompt = initial_prompt_value @@ -243,4 +261,5 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str else: return hf_chat_template(original_model_name, messages) except: - return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) \ No newline at end of file + return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) + \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 35f3d7f17d..8d8930ca49 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -28,7 +28,7 @@ from litellm.utils import ( completion_with_fallbacks, get_llm_provider, get_api_key, - mock_completion_streaming_obj, + mock_completion_streaming_obj ) from .llms import ( anthropic, @@ -48,7 +48,7 @@ from .llms import ( oobabooga, palm, vertex_ai) -from .llms.prompt_templates.factory import prompt_factory, custom_prompt +from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt import tiktoken from concurrent.futures import ThreadPoolExecutor from typing import Callable, List, Optional, Dict @@ -259,27 +259,33 @@ def completion( api_base = "https://proxy.litellm.ai" custom_llm_provider = "openai" api_key = model_api_key + # check if user passed in any of the OpenAI optional params optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - request_timeout=request_timeout, - deployment_id=deployment_id, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - **non_default_params - ) + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + request_timeout=request_timeout, + deployment_id=deployment_id, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + **non_default_params + ) + + if litellm.add_function_to_prompt and optional_params.get("functions_unsupported_model", None): # if user opts to add it to prompt, when API doesn't support function calling + functions_unsupported_model = optional_params.pop("functions_unsupported_model") + messages = function_call_prompt(messages=messages, functions=functions_unsupported_model) + # For logging - save the values of the litellm-specific params passed in litellm_params = get_litellm_params( return_async=return_async, diff --git a/litellm/tests/test_add_function_to_prompt.py b/litellm/tests/test_add_function_to_prompt.py new file mode 100644 index 0000000000..33d2ac2a97 --- /dev/null +++ b/litellm/tests/test_add_function_to_prompt.py @@ -0,0 +1,75 @@ +#### What this tests #### +# Allow the user to map the function to the prompt, if the model doesn't support function calling + +import sys, os, pytest +import traceback + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + +## case 1: set_function_to_prompt not set +def test_function_call_non_openai_model(): + try: + model = "claude-instant-1" + messages=[{"role": "user", "content": "what's the weather in sf?"}] + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + ] + response = litellm.completion(model=model, messages=messages, functions=functions) + pytest.fail(f'An error occurred') + except Exception as e: + pass + +test_function_call_non_openai_model() + +## case 2: add_function_to_prompt set +def test_function_call_non_openai_model_litellm_mod_set(): + litellm.add_function_to_prompt = True + try: + model = "claude-instant-1" + messages=[{"role": "user", "content": "what's the weather in sf?"}] + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + ] + response = litellm.completion(model=model, messages=messages, functions=functions) + print(f'response: {response}') + except Exception as e: + pytest.fail(f'An error occurred {e}') + +# test_function_call_non_openai_model_litellm_mod_set() \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index 9b0c6c7356..92d2f38fb6 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1001,17 +1001,20 @@ def get_optional_params( # use the openai defaults } # filter out those parameters that were passed with non-default values non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])} - + optional_params = {} ## raise exception if function calling passed in for a provider that doesn't support it if "functions" in non_default_params or "function_call" in non_default_params: if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure": - raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider") + if litellm.add_function_to_prompt: # if user opts to add it to prompt instead + optional_params["functions_unsupported_model"] = non_default_params.pop("functions") + else: + raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider") def _check_valid_arg(supported_params): print_verbose(f"checking params for {model}") print_verbose(f"params passed in {passed_params}") print_verbose(f"non-default params passed in {non_default_params}") - unsupported_params = [] + unsupported_params = {} for k in non_default_params.keys(): if k not in supported_params: if k == "n" and n == 1: # langchain sends n=1 as a default value @@ -1020,12 +1023,11 @@ def get_optional_params( # use the openai defaults elif k == "request_timeout": # litellm handles request time outs pass else: - unsupported_params.append(k) + unsupported_params[k] = non_default_params[k] if unsupported_params and not litellm.drop_params: raise ValueError("LiteLLM.Exception: Unsupported parameters passed: {}".format(', '.join(unsupported_params))) ## raise exception if provider doesn't support passed in param - optional_params = {} if custom_llm_provider == "anthropic": ## check if unsupported param passed in supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"]