fix(openai.py): deepinfra function calling - drop_params support for unsupported tool choice value

Krrish Dholakia 2024-06-18 16:19:42 -07:00
parent 604f9689d0
commit 5ad095ad9d
3 changed files with 97 additions and 29 deletions
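
Summary: Deepinfra's function-calling API only accepts tool_choice values of "auto" or "none" (https://deepinfra.com/docs/advanced/function_calling). Previously, litellm forwarded any other tool_choice value through to Deepinfra, which rejects it. With this change, when the global `litellm.drop_params` flag or the per-call `drop_params` argument is True, the unsupported value is dropped before the request is built; otherwise litellm raises an UnsupportedParamsError up front with a 400 status code.

A minimal caller-level sketch of the new behavior (assumes a valid DEEPINFRA_API_KEY in the environment; the tool definition is abbreviated, see the test below for a fuller one):

    import litellm

    # Minimal OpenAI-style tool definition (abbreviated).
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    response = litellm.completion(
        model="deepinfra/meta-llama/Llama-2-70b-chat-hf",
        messages=[{"role": "user", "content": "What's the weather in Boston?"}],
        tools=tools,
        # Deepinfra only supports "auto"/"none"; with drop_params=True the
        # named tool_choice below is dropped instead of raising an error.
        tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
        drop_params=True,
    )
    print(response)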

litellm/llms/openai.py

@@ -1,34 +1,41 @@
+import hashlib
+import json
+import time
+import traceback
+import types
 from typing import (
-    Optional,
-    Union,
     Any,
     BinaryIO,
-    Literal,
+    Callable,
+    Coroutine,
     Iterable,
+    Literal,
+    Optional,
+    Union,
 )
-import hashlib
-from typing_extensions import override, overload
-from pydantic import BaseModel
-import types, time, json, traceback
+
 import httpx
-from .base import BaseLLM
-from litellm.utils import (
-    ModelResponse,
-    Choices,
-    Message,
-    CustomStreamWrapper,
-    convert_to_model_response_object,
-    Usage,
-    TranscriptionResponse,
-    TextCompletionResponse,
-)
-from typing import Callable, Optional, Coroutine
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from openai import OpenAI, AsyncOpenAI
-from ..types.llms.openai import *
 import openai
+from openai import AsyncOpenAI, OpenAI
+from pydantic import BaseModel
+from typing_extensions import overload, override
+
+import litellm
 from litellm.types.utils import ProviderField
+from litellm.utils import (
+    Choices,
+    CustomStreamWrapper,
+    Message,
+    ModelResponse,
+    TextCompletionResponse,
+    TranscriptionResponse,
+    Usage,
+    convert_to_model_response_object,
+)
+
+from ..types.llms.openai import *
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory


 class OpenAIError(Exception):
@@ -306,8 +313,12 @@ class DeepInfraConfig:
         ]

     def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, model: str
-    ):
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
         supported_openai_params = self.get_supported_openai_params()
         for param, value in non_default_params.items():
             if (
@@ -316,8 +327,23 @@ class DeepInfraConfig:
                 and model == "mistralai/Mistral-7B-Instruct-v0.1"
             ):  # this model does not support temperature == 0
                 value = 0.0001  # close to 0
+            if param == "tool_choice":
+                if (
+                    value != "auto" and value != "none"
+                ):  # https://deepinfra.com/docs/advanced/function_calling
+                    ## UNSUPPORTED TOOL CHOICE VALUE
+                    if litellm.drop_params is True or drop_params is True:
+                        value = None
+                    else:
+                        raise litellm.utils.UnsupportedParamsError(
+                            message="Deepinfra doesn't support tool_choice={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
+                                value
+                            ),
+                            status_code=400,
+                        )
             if param in supported_openai_params:
-                optional_params[param] = value
+                if value is not None:
+                    optional_params[param] = value
         return optional_params
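
The drop path works in two steps: an unsupported tool_choice value is first rewritten to None, and the copy loop above then skips None values, so the parameter never reaches the provider request. A small illustrative call (a sketch; the import path assumes the litellm/llms/openai.py module this commit modifies):

    from litellm.llms.openai import DeepInfraConfig

    optional_params = DeepInfraConfig().map_openai_params(
        non_default_params={
            "tool_choice": {
                "type": "function",
                "function": {"name": "get_current_weather"},
            }
        },
        optional_params={},
        model="meta-llama/Llama-2-70b-chat-hf",
        drop_params=True,
    )
    print(optional_params)  # {} -- the unsupported tool_choice was dropped
    # With drop_params=False (and litellm.drop_params left at its default),
    # the same call raises litellm.utils.UnsupportedParamsError instead.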

litellm/tests/test_completion.py

@@ -3355,17 +3355,54 @@ def test_completion_ai21():
 # test_completion_ai21()
 # test_completion_ai21()
 ## test deep infra
-def test_completion_deep_infra():
+@pytest.mark.parametrize("drop_params", [True, False])
+def test_completion_deep_infra(drop_params):
     litellm.set_verbose = False
     model_name = "deepinfra/meta-llama/Llama-2-70b-chat-hf"
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in Boston today in Fahrenheit?",
+        }
+    ]
     try:
         response = completion(
-            model=model_name, messages=messages, temperature=0, max_tokens=10
+            model=model_name,
+            messages=messages,
+            temperature=0,
+            max_tokens=10,
+            tools=tools,
+            tool_choice={
+                "type": "function",
+                "function": {"name": "get_current_weather"},
+            },
+            drop_params=drop_params,
         )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+        if drop_params is True:
+            pytest.fail(f"Error occurred: {e}")


 # test_completion_deep_infra()
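
Note the asymmetric assertion at the end of the test: with drop_params=False the completion call is expected to raise (the unsupported tool_choice is not dropped), so the except branch only fails the test when drop_params is True. To run just this parametrized test (the test-file path here is an assumption):

    pytest litellm/tests/test_completion.py -k test_completion_deep_infra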

litellm/utils.py

@@ -2950,6 +2950,11 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "perplexity":
         supported_params = get_supported_openai_params(