(feat) set api_version for Azure

This commit is contained in:
ishaan-jaff 2024-03-08 13:38:29 -08:00
parent 079ba68197
commit f2c6bb9d3e

View file

@ -270,6 +270,14 @@ class AzureChatCompletion(BaseLLM):
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault(
"api-version", api_version
)
response = azure_client.chat.completions.create(**data, timeout=timeout) # type: ignore
stringified_response = response.model_dump()
## LOGGING
@ -408,6 +416,10 @@ class AzureChatCompletion(BaseLLM):
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@ -461,6 +473,11 @@ class AzureChatCompletion(BaseLLM):
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["messages"],