fix(utils.py): correctly instrument passing through api version in optional param check

This commit is contained in:
Krrish Dholakia 2024-06-01 19:31:34 -07:00
parent 373a41ca6d
commit 9ef83126d7
4 changed files with 6 additions and 4 deletions

View file

@ -147,7 +147,6 @@ class AzureOpenAIConfig:
api_version_year = api_version_times[0] api_version_year = api_version_times[0]
api_version_month = api_version_times[1] api_version_month = api_version_times[1]
api_version_day = api_version_times[2] api_version_day = api_version_times[2]
args = locals()
for param, value in non_default_params.items(): for param, value in non_default_params.items():
if param == "tool_choice": if param == "tool_choice":
""" """

View file

@ -838,6 +838,7 @@ def completion(
logprobs=logprobs, logprobs=logprobs,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
extra_headers=extra_headers, extra_headers=extra_headers,
api_version=api_version,
**non_default_params, **non_default_params,
) )

View file

@ -2828,6 +2828,7 @@ async def test_azure_astreaming_and_function_calling():
password=os.environ["REDIS_PASSWORD"], password=os.environ["REDIS_PASSWORD"],
) )
try: try:
litellm.set_verbose = True
response = await litellm.acompletion( response = await litellm.acompletion(
model="azure/gpt-4-nov-release", model="azure/gpt-4-nov-release",
tools=tools, tools=tools,

View file

@ -5192,6 +5192,7 @@ def get_optional_params(
logprobs=None, logprobs=None,
top_logprobs=None, top_logprobs=None,
extra_headers=None, extra_headers=None,
api_version=None,
**kwargs, **kwargs,
): ):
# retrieve all parameters passed to the function # retrieve all parameters passed to the function
@ -5262,6 +5263,7 @@ def get_optional_params(
"logprobs": None, "logprobs": None,
"top_logprobs": None, "top_logprobs": None,
"extra_headers": None, "extra_headers": None,
"api_version": None,
} }
# filter out those parameters that were passed with non-default values # filter out those parameters that were passed with non-default values
non_default_params = { non_default_params = {
@ -5270,6 +5272,7 @@ def get_optional_params(
if ( if (
k != "model" k != "model"
and k != "custom_llm_provider" and k != "custom_llm_provider"
and k != "api_version"
and k in default_params and k in default_params
and v != default_params[k] and v != default_params[k]
) )
@ -6051,9 +6054,7 @@ def get_optional_params(
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
api_version = ( api_version = (
passed_params.get("api_version", None) api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) )
optional_params = litellm.AzureOpenAIConfig().map_openai_params( optional_params = litellm.AzureOpenAIConfig().map_openai_params(
non_default_params=non_default_params, non_default_params=non_default_params,