fix(azure.py): support dropping 'tool_choice=required' for older azure API versions

Closes https://github.com/BerriAI/litellm/issues/3876
Krrish Dholakia 2024-06-01 18:44:29 -07:00
parent e149ca73f6
commit 23087295e1
3 changed files with 157 additions and 41 deletions
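
For context, a minimal sketch of the behavior this commit enables, seen from the caller's side (the deployment name and API version mirror the test added below and are illustrative, not required values):

```python
import litellm

litellm.drop_params = True  # drop unsupported params instead of raising

optional_params = litellm.utils.get_optional_params(
    model="chatgpt-v-2",          # illustrative Azure deployment name
    custom_llm_provider="azure",
    max_tokens=10,
    tool_choice="required",       # not supported on this older Azure API version
    api_version="2024-02-01",
)

# 'tool_choice' is silently dropped rather than sent to Azure
assert "tool_choice" not in optional_params
```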

litellm/llms/azure.py

@@ -9,6 +9,7 @@ from litellm.utils import (
     convert_to_model_response_object,
     TranscriptionResponse,
     get_secret,
+    UnsupportedParamsError,
 )
 from typing import Callable, Optional, BinaryIO, List
 from litellm import OpenAIConfig
@@ -45,9 +46,9 @@ class AzureOpenAIError(Exception):
         )  # Call the base class constructor with the parameters it needs


-class AzureOpenAIConfig(OpenAIConfig):
+class AzureOpenAIConfig:
     """
-    Reference: https://platform.openai.com/docs/api-reference/chat/create
+    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

     The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
@@ -85,18 +86,103 @@ class AzureOpenAIConfig(OpenAIConfig):
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
     ) -> None:
-        super().__init__(
-            frequency_penalty,
-            function_call,
-            functions,
-            logit_bias,
-            max_tokens,
-            n,
-            presence_penalty,
-            stop,
-            temperature,
-            top_p,
-        )
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "temperature",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "function_call",
+            "functions",
+            "tools",
+            "tool_choice",
+            "top_p",
+            "log_probs",
+            "top_logprobs",
+            "response_format",
+            "seed",
+            "extra_headers",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        api_version: str,  # Y-M-D-{optional}
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params()
+        api_version_times = api_version.split("-")
+        api_version_year = api_version_times[0]
+        api_version_month = api_version_times[1]
+        api_version_day = api_version_times[2]
+        args = locals()
+        for param, value in non_default_params.items():
+            if param == "tool_choice":
+                """
+                This parameter requires API version 2023-12-01-preview or later
+                tool_choice='required' is not supported as of 2024-05-01-preview
+                """
+                ## check if api version supports this param ##
+                if (
+                    api_version_year < "2023"
+                    or (api_version_year == "2023" and api_version_month < "12")
+                    or (
+                        api_version_year == "2023"
+                        and api_version_month == "12"
+                        and api_version_day < "01"
+                    )
+                ):
+                    if litellm.drop_params == False:
+                        raise UnsupportedParamsError(
+                            status_code=400,
+                            message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
+                        )
+                elif value == "required" and (
+                    api_version_year == "2024" and api_version_month <= "05"
+                ):  ## check if tool_choice value is supported ##
+                    if litellm.drop_params == False:
+                        raise UnsupportedParamsError(
+                            status_code=400,
+                            message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
+                        )
+                else:
+                    optional_params["tool_choice"] = value
+            elif param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params

     def get_mapped_special_auth_params(self) -> dict:
         return {"token": "azure_ad_token"}
@@ -172,9 +258,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
     possible_azure_ad_token = req_token.json().get("access_token", None)

     if possible_azure_ad_token is None:
-        raise AzureOpenAIError(
-            status_code=422, message="Azure AD Token not returned"
-        )
+        raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")

     return possible_azure_ad_token
@@ -245,7 +329,9 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
                 if azure_ad_token.startswith("oidc/"):
-                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+                    azure_ad_token = get_azure_ad_token_from_oidc(
+                        azure_ad_token
+                    )
                 azure_client_params["azure_ad_token"] = azure_ad_token
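
A sketch of the new hook in isolation, using the same entry point the utils.py change below wires up (`litellm.AzureOpenAIConfig().map_openai_params`); the deployment name and API version are illustrative:

```python
import litellm

litellm.drop_params = True  # drop unsupported params instead of raising

config = litellm.AzureOpenAIConfig()
params = config.map_openai_params(
    non_default_params={"tool_choice": "required", "temperature": 0.2},
    optional_params={},
    model="chatgpt-v-2",               # illustrative deployment name
    api_version="2024-03-01-preview",  # predates tool_choice='required' support
)
print(params)  # {'temperature': 0.2} -- tool_choice was dropped
```

Supported params like `temperature` pass through unchanged; only the unsupported `tool_choice` value is filtered by the version gate.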

litellm/tests/test_optional_params.py

@@ -207,3 +207,39 @@ def test_openai_extra_headers():
     assert optional_params["max_tokens"] == 10
     assert optional_params["temperature"] == 0.2
     assert optional_params["extra_headers"] == {"AI-Resource Group": "ishaan-resource"}
+
+
+@pytest.mark.parametrize(
+    "api_version",
+    [
+        "2024-02-01",
+        "2024-07-01",  # potential future version with tool_choice="required" supported
+        "2023-07-01-preview",
+        "2024-03-01-preview",
+    ],
+)
+def test_azure_tool_choice(api_version):
+    """
+    Test azure tool choice on older + new version
+    """
+    litellm.drop_params = True
+    optional_params = litellm.utils.get_optional_params(
+        model="chatgpt-v-2",
+        user="John",
+        custom_llm_provider="azure",
+        max_tokens=10,
+        temperature=0.2,
+        extra_headers={"AI-Resource Group": "ishaan-resource"},
+        tool_choice="required",
+        api_version=api_version,
+    )
+
+    print(f"{optional_params}")
+    if api_version == "2024-07-01":
+        assert optional_params["tool_choice"] == "required"
+    else:
+        assert (
+            "tool_choice" not in optional_params
+        ), "tool_choice={} for api version={}".format(
+            optional_params["tool_choice"], api_version
+        )
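
The test only exercises the drop path. For the opposite path, with `litellm.drop_params=False` (the default) the same call raises the newly imported `UnsupportedParamsError`; a hedged sketch:

```python
import litellm
from litellm.utils import UnsupportedParamsError

litellm.drop_params = False  # default: unsupported params raise instead

try:
    litellm.utils.get_optional_params(
        model="chatgpt-v-2",       # illustrative deployment name
        custom_llm_provider="azure",
        tool_choice="required",
        api_version="2024-02-01",  # predates tool_choice='required' support
    )
except UnsupportedParamsError as err:
    # The error message suggests litellm.drop_params=True, or for the proxy:
    #   litellm_settings:
    #     drop_params: true
    print(err)
```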

litellm/utils.py

@@ -6045,6 +6045,22 @@ def get_optional_params(
             optional_params=optional_params,
             model=model,
         )
+    elif custom_llm_provider == "azure":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider="azure"
+        )
+        _check_valid_arg(supported_params=supported_params)
+        api_version = (
+            passed_params.get("api_version", None)
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )
+        optional_params = litellm.AzureOpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            api_version=api_version,  # type: ignore
+        )
     else:  # assume passing in params for azure openai
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider="azure"
@@ -6481,29 +6497,7 @@ def get_supported_openai_params(
     elif custom_llm_provider == "openai":
         return litellm.OpenAIConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "azure":
-        return [
-            "functions",
-            "function_call",
-            "temperature",
-            "top_p",
-            "n",
-            "stream",
-            "stream_options",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "max_retries",
-            "logprobs",
-            "top_logprobs",
-            "extra_headers",
-        ]
+        return litellm.AzureOpenAIConfig().get_supported_openai_params()
     elif custom_llm_provider == "openrouter":
         return [
             "functions",