mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
fix(utils.py): add support for anyscale function calling
This commit is contained in:
parent
89b808d767
commit
77b11daf28
2 changed files with 33 additions and 29 deletions
|
@ -599,34 +599,34 @@ def test_completion_hf_model_no_provider():
|
|||
|
||||
# test_completion_hf_model_no_provider()
|
||||
|
||||
# def test_completion_openai_azure_with_functions():
|
||||
# function1 = [
|
||||
# {
|
||||
# "name": "get_current_weather",
|
||||
# "description": "Get the current weather in a given location",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state, e.g. San Francisco, CA",
|
||||
# },
|
||||
# "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
# },
|
||||
# "required": ["location"],
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
# response = completion(
|
||||
# model="azure/chatgpt-functioncalling", messages=messages, functions=function1
|
||||
# )
|
||||
# # Add any assertions here to check the response
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_openai_azure_with_functions()
|
||||
def test_completion_anyscale_with_functions():
|
||||
function1 = [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
]
|
||||
try:
|
||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
response = completion(
|
||||
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, functions=function1
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
test_completion_anyscale_with_functions()
|
||||
|
||||
def test_completion_azure_key_completion_arg():
|
||||
# this tests if we can pass api_key to completion, when it's not in the env
|
||||
|
|
|
@ -2398,6 +2398,8 @@ def get_optional_params( # use the openai defaults
|
|||
optional_params["format"] = "json"
|
||||
litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt
|
||||
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
|
||||
elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral
|
||||
pass
|
||||
elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead
|
||||
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
|
||||
else:
|
||||
|
@ -2825,7 +2827,9 @@ def get_optional_params( # use the openai defaults
|
|||
if frequency_penalty:
|
||||
optional_params["frequency_penalty"] = frequency_penalty
|
||||
elif custom_llm_provider == "anyscale":
|
||||
supported_params = ["temperature", "top_p", "stream", "max_tokens"]
|
||||
supported_params = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"]
|
||||
if model == "mistralai/Mistral-7B-Instruct-v0.1":
|
||||
supported_params += ["functions", "function_call", "tools", "tool_choice"]
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = non_default_params
|
||||
if temperature is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue