fix(utils.py): add support for anyscale function calling

Krrish Dholakia 2023-12-20 17:48:33 +05:30
parent 89b808d767
commit 77b11daf28
2 changed files with 33 additions and 29 deletions
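
In practice, this change lets an OpenAI-style function-calling request flow through to Anyscale for its Mistral model. Below is a minimal usage sketch mirroring the test added in this commit; it assumes litellm is installed and that ANYSCALE_API_KEY is set in the environment:

# Minimal sketch, mirroring the test added in this commit; assumes
# litellm is installed and ANYSCALE_API_KEY is set in the environment.
from litellm import completion

functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

response = completion(
    # the one anyscale model gated for function calling in this commit
    model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=functions,
)
print(response)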


@@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider():
# test_completion_hf_model_no_provider()
-# def test_completion_openai_azure_with_functions():
-#     function1 = [
-#         {
-#             "name": "get_current_weather",
-#             "description": "Get the current weather in a given location",
-#             "parameters": {
-#                 "type": "object",
-#                 "properties": {
-#                     "location": {
-#                         "type": "string",
-#                         "description": "The city and state, e.g. San Francisco, CA",
-#                     },
-#                     "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
-#                 },
-#                 "required": ["location"],
-#             },
-#         }
-#     ]
-#     try:
-#         messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
-#         response = completion(
-#             model="azure/chatgpt-functioncalling", messages=messages, functions=function1
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# test_completion_openai_azure_with_functions()
+def test_completion_anyscale_with_functions():
+    function1 = [
+        {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                },
+                "required": ["location"],
+            },
+        }
+    ]
+    try:
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+        response = completion(
+            model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_completion_anyscale_with_functions()
def test_completion_azure_key_completion_arg():
    # this tests if we can pass api_key to completion, when it's not in the env

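For reference, litellm normalizes provider responses to the OpenAI format, so when the model decides to call the function, the result should surface on the first choice's message. Continuing the sketch above; the exact response-object access below is an assumption about that format, not something shown in this diff:

# Hedged sketch: assumes the OpenAI-format response shape litellm returns.
message = response["choices"][0]["message"]
if message.get("function_call") is not None:
    # in the OpenAI format, `arguments` arrives as a JSON-encoded string
    print(message["function_call"]["name"])        # e.g. "get_current_weather"
    print(message["function_call"]["arguments"])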

@@ -2398,6 +2398,8 @@ def get_optional_params( # use the openai defaults
optional_params["format"] = "json"
litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral
pass
elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
else:
@@ -2825,7 +2827,9 @@ def get_optional_params( # use the openai defaults
        if frequency_penalty:
            optional_params["frequency_penalty"] = frequency_penalty
    elif custom_llm_provider == "anyscale":
-        supported_params = ["temperature", "top_p", "stream", "max_tokens"]
+        supported_params = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"]
+        if model == "mistralai/Mistral-7B-Instruct-v0.1":
+            supported_params += ["functions", "function_call", "tools", "tool_choice"]
        _check_valid_arg(supported_params=supported_params)
        optional_params = non_default_params
        if temperature is not None:
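
Taken together, the utils.py hunks whitelist parameters per provider: every anyscale request may carry the base sampling parameters, while only mistralai/Mistral-7B-Instruct-v0.1 additionally accepts the function-calling ones. A standalone sketch of that gating follows, with check_valid_args as a hypothetical stand-in for the internal _check_valid_arg closure:

# Standalone sketch of the whitelist gating above. `check_valid_args` is a
# hypothetical stand-in for litellm's internal _check_valid_arg helper.
def anyscale_supported_params(model: str) -> list:
    supported = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"]
    if model == "mistralai/Mistral-7B-Instruct-v0.1":
        # anyscale just supports function calling with mistral
        supported += ["functions", "function_call", "tools", "tool_choice"]
    return supported

def check_valid_args(non_default_params: dict, model: str) -> None:
    unsupported = [k for k in non_default_params if k not in anyscale_supported_params(model)]
    if unsupported:
        raise ValueError(f"anyscale does not support {unsupported} for model={model}")

# passing `functions` with a non-Mistral anyscale model is rejected
try:
    check_valid_args({"temperature": 0.7, "functions": []}, "meta-llama/Llama-2-7b-chat-hf")
except ValueError as e:
    print(e)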