feat(azure.py): support dynamic api versions

Closes https://github.com/BerriAI/litellm/issues/5228
This commit is contained in:
Krrish Dholakia 2024-08-19 12:17:43 -07:00
parent 417547b6f9
commit 49416e121c
5 changed files with 176 additions and 32 deletions

View file

@@ -403,6 +403,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
return azure_ad_token_access_token
def _check_dynamic_azure_params(
azure_client_params: dict,
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
) -> bool:
"""
Returns True if user passed in client params != initialized azure client
Currently only implemented for api version
"""
if azure_client is None:
return True
dynamic_params = ["api_version"]
for k, v in azure_client_params.items():
if k in dynamic_params and k == "api_version":
if v is not None and v != azure_client._custom_query["api-version"]:
return True
return False
class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
    # Plain pass-through constructor: all shared state/setup lives in BaseLLM.
    """Initialize the Azure chat-completion handler via the BaseLLM base class."""
    super().__init__()
@@ -462,6 +483,28 @@ class AzureChatCompletion(BaseLLM):
return azure_client
def make_sync_azure_openai_chat_completion_request(
    self,
    azure_client: AzureOpenAI,
    data: dict,
    timeout: Union[float, httpx.Timeout],
):
    """
    Execute a synchronous Azure chat-completion call via the raw-response API.

    Always calls ``chat.completions.with_raw_response.create`` so the HTTP
    response headers (rate-limit info, request ids, ...) are available to the
    caller alongside the parsed completion object.

    Args:
        azure_client: initialized sync Azure client to issue the request with.
        data: kwargs forwarded to chat.completions.create (model, messages, ...).
        timeout: per-request timeout.

    Returns:
        Tuple of (response headers as a plain dict, parsed completion response).
    """
    # No try/except here: the previous `except Exception as e: raise e` was a
    # no-op re-raise, so exceptions simply propagate to the caller.
    raw_response = azure_client.chat.completions.with_raw_response.create(
        **data, timeout=timeout
    )
    headers = dict(raw_response.headers)
    response = raw_response.parse()
    return headers, response
async def make_azure_openai_chat_completion_request(
self,
azure_client: AsyncAzureOpenAI,
@@ -494,6 +537,7 @@ class AzureChatCompletion(BaseLLM):
api_version: str,
api_type: str,
azure_ad_token: str,
dynamic_params: bool,
print_verbose: Callable,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
@@ -558,6 +602,7 @@ class AzureChatCompletion(BaseLLM):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
dynamic_params=dynamic_params,
data=data,
model=model,
api_key=api_key,
@@ -575,6 +620,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
dynamic_params=dynamic_params,
timeout=timeout,
client=client,
logging_obj=logging_obj,
@@ -583,6 +629,7 @@ class AzureChatCompletion(BaseLLM):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
dynamic_params=dynamic_params,
data=data,
model=model,
api_key=api_key,
@@ -628,7 +675,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
@@ -640,7 +688,9 @@ class AzureChatCompletion(BaseLLM):
"api-version", api_version
)
response = azure_client.chat.completions.create(**data, timeout=timeout) # type: ignore
headers, response = self.make_sync_azure_openai_chat_completion_request(
azure_client=azure_client, data=data, timeout=timeout
)
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
@@ -674,6 +724,7 @@ class AzureChatCompletion(BaseLLM):
api_base: str,
data: dict,
timeout: Any,
dynamic_params: bool,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
@@ -707,15 +758,11 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client
if client is None:
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -786,6 +833,7 @@ class AzureChatCompletion(BaseLLM):
api_base: str,
api_key: str,
api_version: str,
dynamic_params: bool,
data: dict,
model: str,
timeout: Any,
@@ -815,13 +863,11 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["messages"],
@@ -833,7 +879,9 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
response = azure_client.chat.completions.create(**data, timeout=timeout)
headers, response = self.make_sync_azure_openai_chat_completion_request(
azure_client=azure_client, data=data, timeout=timeout
)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
@@ -848,6 +896,7 @@ class AzureChatCompletion(BaseLLM):
api_base: str,
api_key: str,
api_version: str,
dynamic_params: bool,
data: dict,
model: str,
timeout: Any,
@@ -873,15 +922,10 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING
logging_obj.pre_call(
input=data["messages"],