mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
fix(azure.py): support dropping 'tool_choice=required' for older azure API versions
Closes https://github.com/BerriAI/litellm/issues/3876
This commit is contained in:
parent
e149ca73f6
commit
23087295e1
3 changed files with 157 additions and 41 deletions
|
@ -9,6 +9,7 @@ from litellm.utils import (
|
||||||
convert_to_model_response_object,
|
convert_to_model_response_object,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
get_secret,
|
get_secret,
|
||||||
|
UnsupportedParamsError,
|
||||||
)
|
)
|
||||||
from typing import Callable, Optional, BinaryIO, List
|
from typing import Callable, Optional, BinaryIO, List
|
||||||
from litellm import OpenAIConfig
|
from litellm import OpenAIConfig
|
||||||
|
@ -45,9 +46,9 @@ class AzureOpenAIError(Exception):
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIConfig(OpenAIConfig):
|
class AzureOpenAIConfig:
|
||||||
"""
|
"""
|
||||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||||
|
|
||||||
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
|
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
|
||||||
|
|
||||||
|
@ -85,18 +86,103 @@ class AzureOpenAIConfig(OpenAIConfig):
|
||||||
temperature: Optional[int] = None,
|
temperature: Optional[int] = None,
|
||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
locals_ = locals().copy()
|
||||||
frequency_penalty,
|
for key, value in locals_.items():
|
||||||
function_call,
|
if key != "self" and value is not None:
|
||||||
functions,
|
setattr(self.__class__, key, value)
|
||||||
logit_bias,
|
|
||||||
max_tokens,
|
@classmethod
|
||||||
n,
|
def get_config(cls):
|
||||||
presence_penalty,
|
return {
|
||||||
stop,
|
k: v
|
||||||
temperature,
|
for k, v in cls.__dict__.items()
|
||||||
top_p,
|
if not k.startswith("__")
|
||||||
)
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"temperature",
|
||||||
|
"n",
|
||||||
|
"stream",
|
||||||
|
"stop",
|
||||||
|
"max_tokens",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"logit_bias",
|
||||||
|
"user",
|
||||||
|
"function_call",
|
||||||
|
"functions",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"top_p",
|
||||||
|
"log_probs",
|
||||||
|
"top_logprobs",
|
||||||
|
"response_format",
|
||||||
|
"seed",
|
||||||
|
"extra_headers",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
api_version: str, # Y-M-D-{optional}
|
||||||
|
) -> dict:
|
||||||
|
supported_openai_params = self.get_supported_openai_params()
|
||||||
|
|
||||||
|
api_version_times = api_version.split("-")
|
||||||
|
api_version_year = api_version_times[0]
|
||||||
|
api_version_month = api_version_times[1]
|
||||||
|
api_version_day = api_version_times[2]
|
||||||
|
args = locals()
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "tool_choice":
|
||||||
|
"""
|
||||||
|
This parameter requires API version 2023-12-01-preview or later
|
||||||
|
|
||||||
|
tool_choice='required' is not supported as of 2024-05-01-preview
|
||||||
|
"""
|
||||||
|
## check if api version supports this param ##
|
||||||
|
if (
|
||||||
|
api_version_year < "2023"
|
||||||
|
or (api_version_year == "2023" and api_version_month < "12")
|
||||||
|
or (
|
||||||
|
api_version_year == "2023"
|
||||||
|
and api_version_month == "12"
|
||||||
|
and api_version_day < "01"
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if litellm.drop_params == False:
|
||||||
|
raise UnsupportedParamsError(
|
||||||
|
status_code=400,
|
||||||
|
message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
|
||||||
|
)
|
||||||
|
elif value == "required" and (
|
||||||
|
api_version_year == "2024" and api_version_month <= "05"
|
||||||
|
): ## check if tool_choice value is supported ##
|
||||||
|
if litellm.drop_params == False:
|
||||||
|
raise UnsupportedParamsError(
|
||||||
|
status_code=400,
|
||||||
|
message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optional_params["tool_choice"] = value
|
||||||
|
elif param in supported_openai_params:
|
||||||
|
optional_params[param] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
def get_mapped_special_auth_params(self) -> dict:
|
def get_mapped_special_auth_params(self) -> dict:
|
||||||
return {"token": "azure_ad_token"}
|
return {"token": "azure_ad_token"}
|
||||||
|
@ -172,9 +258,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
possible_azure_ad_token = req_token.json().get("access_token", None)
|
possible_azure_ad_token = req_token.json().get("access_token", None)
|
||||||
|
|
||||||
if possible_azure_ad_token is None:
|
if possible_azure_ad_token is None:
|
||||||
raise AzureOpenAIError(
|
raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")
|
||||||
status_code=422, message="Azure AD Token not returned"
|
|
||||||
)
|
|
||||||
|
|
||||||
return possible_azure_ad_token
|
return possible_azure_ad_token
|
||||||
|
|
||||||
|
@ -245,7 +329,9 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(
|
||||||
|
azure_ad_token
|
||||||
|
)
|
||||||
|
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
|
|
|
@ -207,3 +207,39 @@ def test_openai_extra_headers():
|
||||||
assert optional_params["max_tokens"] == 10
|
assert optional_params["max_tokens"] == 10
|
||||||
assert optional_params["temperature"] == 0.2
|
assert optional_params["temperature"] == 0.2
|
||||||
assert optional_params["extra_headers"] == {"AI-Resource Group": "ishaan-resource"}
|
assert optional_params["extra_headers"] == {"AI-Resource Group": "ishaan-resource"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_version",
|
||||||
|
[
|
||||||
|
"2024-02-01",
|
||||||
|
"2024-07-01", # potential future version with tool_choice="required" supported
|
||||||
|
"2023-07-01-preview",
|
||||||
|
"2024-03-01-preview",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_azure_tool_choice(api_version):
|
||||||
|
"""
|
||||||
|
Test azure tool choice on older + new version
|
||||||
|
"""
|
||||||
|
litellm.drop_params = True
|
||||||
|
optional_params = litellm.utils.get_optional_params(
|
||||||
|
model="chatgpt-v-2",
|
||||||
|
user="John",
|
||||||
|
custom_llm_provider="azure",
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.2,
|
||||||
|
extra_headers={"AI-Resource Group": "ishaan-resource"},
|
||||||
|
tool_choice="required",
|
||||||
|
api_version=api_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{optional_params}")
|
||||||
|
if api_version == "2024-07-01":
|
||||||
|
assert optional_params["tool_choice"] == "required"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
"tool_choice" not in optional_params
|
||||||
|
), "tool_choice={} for api version={}".format(
|
||||||
|
optional_params["tool_choice"], api_version
|
||||||
|
)
|
||||||
|
|
|
@ -6045,6 +6045,22 @@ def get_optional_params(
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
supported_params = get_supported_openai_params(
|
||||||
|
model=model, custom_llm_provider="azure"
|
||||||
|
)
|
||||||
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
api_version = (
|
||||||
|
passed_params.get("api_version", None)
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
)
|
||||||
|
optional_params = litellm.AzureOpenAIConfig().map_openai_params(
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model=model,
|
||||||
|
api_version=api_version, # type: ignore
|
||||||
|
)
|
||||||
else: # assume passing in params for azure openai
|
else: # assume passing in params for azure openai
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider="azure"
|
model=model, custom_llm_provider="azure"
|
||||||
|
@ -6481,29 +6497,7 @@ def get_supported_openai_params(
|
||||||
elif custom_llm_provider == "openai":
|
elif custom_llm_provider == "openai":
|
||||||
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
|
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
|
||||||
elif custom_llm_provider == "azure":
|
elif custom_llm_provider == "azure":
|
||||||
return [
|
return litellm.AzureOpenAIConfig().get_supported_openai_params()
|
||||||
"functions",
|
|
||||||
"function_call",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"n",
|
|
||||||
"stream",
|
|
||||||
"stream_options",
|
|
||||||
"stop",
|
|
||||||
"max_tokens",
|
|
||||||
"presence_penalty",
|
|
||||||
"frequency_penalty",
|
|
||||||
"logit_bias",
|
|
||||||
"user",
|
|
||||||
"response_format",
|
|
||||||
"seed",
|
|
||||||
"tools",
|
|
||||||
"tool_choice",
|
|
||||||
"max_retries",
|
|
||||||
"logprobs",
|
|
||||||
"top_logprobs",
|
|
||||||
"extra_headers",
|
|
||||||
]
|
|
||||||
elif custom_llm_provider == "openrouter":
|
elif custom_llm_provider == "openrouter":
|
||||||
return [
|
return [
|
||||||
"functions",
|
"functions",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue