fix(utils.py): add support for anyscale function calling

This commit is contained in:
Krrish Dholakia 2023-12-20 17:48:33 +05:30
parent 89b808d767
commit 77b11daf28
2 changed files with 33 additions and 29 deletions

View file

@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider():
# test_completion_hf_model_no_provider() # test_completion_hf_model_no_provider()
# def test_completion_openai_azure_with_functions(): def test_completion_anyscale_with_functions():
# function1 = [ function1 = [
# { {
# "name": "get_current_weather", "name": "get_current_weather",
# "description": "Get the current weather in a given location", "description": "Get the current weather in a given location",
# "parameters": { "parameters": {
# "type": "object", "type": "object",
# "properties": { "properties": {
# "location": { "location": {
# "type": "string", "type": "string",
# "description": "The city and state, e.g. San Francisco, CA", "description": "The city and state, e.g. San Francisco, CA",
# }, },
# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
# }, },
# "required": ["location"], "required": ["location"],
# }, },
# } }
# ] ]
# try: try:
# messages = [{"role": "user", "content": "What is the weather like in Boston?"}] messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
# response = completion( response = completion(
# model="azure/chatgpt-functioncalling", messages=messages, functions=function1 model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1
# ) )
# # Add any assertions here to check the response # Add any assertions here to check the response
# print(response) print(response)
# except Exception as e: except Exception as e:
# pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_openai_azure_with_functions() test_completion_anyscale_with_functions()
def test_completion_azure_key_completion_arg(): def test_completion_azure_key_completion_arg():
# this tests if we can pass api_key to completion, when it's not in the env # this tests if we can pass api_key to completion, when it's not in the env

View file

@ -2398,6 +2398,8 @@ def get_optional_params( # use the openai defaults
optional_params["format"] = "json" optional_params["format"] = "json"
litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral
pass
elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
else: else:
@ -2825,7 +2827,9 @@ def get_optional_params( # use the openai defaults
if frequency_penalty: if frequency_penalty:
optional_params["frequency_penalty"] = frequency_penalty optional_params["frequency_penalty"] = frequency_penalty
elif custom_llm_provider == "anyscale": elif custom_llm_provider == "anyscale":
supported_params = ["temperature", "top_p", "stream", "max_tokens"] supported_params = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"]
if model == "mistralai/Mistral-7B-Instruct-v0.1":
supported_params += ["functions", "function_call", "tools", "tool_choice"]
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
optional_params = non_default_params optional_params = non_default_params
if temperature is not None: if temperature is not None: