mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(azure.py): support dynamic api versions
Closes https://github.com/BerriAI/litellm/issues/5228
This commit is contained in:
parent
417547b6f9
commit
49416e121c
5 changed files with 176 additions and 32 deletions
|
@ -403,6 +403,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
|||
return azure_ad_token_access_token
|
||||
|
||||
|
||||
def _check_dynamic_azure_params(
    azure_client_params: dict,
    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
) -> bool:
    """
    Return True if the caller-supplied client params differ from the ones the
    existing azure client was initialized with, i.e. the client must be
    re-created before use.

    Currently only implemented for api version.

    Args:
        azure_client_params: Params the caller wants the client built with.
        azure_client: Previously initialized client, or None if none exists.

    Returns:
        True when there is no client to reuse, or when the requested
        ``api_version`` differs from the one baked into the existing client's
        custom query; False otherwise.
    """
    if azure_client is None:
        # Nothing to reuse - caller must create a fresh client.
        return True

    dynamic_params = ["api_version"]
    for k, v in azure_client_params.items():
        if k in dynamic_params and k == "api_version":
            # Use .get() so a client whose custom query was never populated
            # with "api-version" triggers re-creation instead of a KeyError.
            if v is not None and v != azure_client._custom_query.get("api-version"):
                return True

    return False
|
||||
|
||||
|
||||
class AzureChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
@ -462,6 +483,28 @@ class AzureChatCompletion(BaseLLM):
|
|||
|
||||
return azure_client
|
||||
|
||||
def make_sync_azure_openai_chat_completion_request(
    self,
    azure_client: AzureOpenAI,
    data: dict,
    timeout: Union[float, httpx.Timeout],
):
    """
    Helper that always calls ``chat.completions.with_raw_response.create``
    so the Azure response headers are captured alongside the parsed response.

    Args:
        azure_client: Initialized synchronous Azure OpenAI client.
        data: Request payload, forwarded as keyword args to ``.create()``.
        timeout: Per-request timeout passed through to the client.

    Returns:
        Tuple of (response headers as a plain dict, parsed chat completion).

    Raises:
        Re-raises, unchanged, whatever the underlying openai client raises.
    """
    try:
        raw_response = azure_client.chat.completions.with_raw_response.create(
            **data, timeout=timeout
        )

        headers = dict(raw_response.headers)
        response = raw_response.parse()
        return headers, response
    except Exception:
        # Bare `raise` preserves the original traceback for callers;
        # `raise e` would truncate it to this frame.
        raise
|
||||
|
||||
async def make_azure_openai_chat_completion_request(
|
||||
self,
|
||||
azure_client: AsyncAzureOpenAI,
|
||||
|
@ -494,6 +537,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_version: str,
|
||||
api_type: str,
|
||||
azure_ad_token: str,
|
||||
dynamic_params: bool,
|
||||
print_verbose: Callable,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
|
@ -558,6 +602,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
return self.async_streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
dynamic_params=dynamic_params,
|
||||
data=data,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
|
@ -575,6 +620,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_version=api_version,
|
||||
model=model,
|
||||
azure_ad_token=azure_ad_token,
|
||||
dynamic_params=dynamic_params,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
|
@ -583,6 +629,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
dynamic_params=dynamic_params,
|
||||
data=data,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
|
@ -628,7 +675,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
|
@ -640,7 +688,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
"api-version", api_version
|
||||
)
|
||||
|
||||
response = azure_client.chat.completions.create(**data, timeout=timeout) # type: ignore
|
||||
headers, response = self.make_sync_azure_openai_chat_completion_request(
|
||||
azure_client=azure_client, data=data, timeout=timeout
|
||||
)
|
||||
stringified_response = response.model_dump()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -674,6 +724,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_base: str,
|
||||
data: dict,
|
||||
timeout: Any,
|
||||
dynamic_params: bool,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
|
@ -707,15 +758,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
# setting Azure client
|
||||
if client is None:
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(
|
||||
azure_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["messages"],
|
||||
|
@ -786,6 +833,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
dynamic_params: bool,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
|
@ -815,13 +863,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(azure_client._custom_query, dict):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["messages"],
|
||||
|
@ -833,7 +879,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
response = azure_client.chat.completions.create(**data, timeout=timeout)
|
||||
headers, response = self.make_sync_azure_openai_chat_completion_request(
|
||||
azure_client=azure_client, data=data, timeout=timeout
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
|
@ -848,6 +896,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
dynamic_params: bool,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
|
@ -873,15 +922,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
if client is None or dynamic_params:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
if api_version is not None and isinstance(
|
||||
azure_client._custom_query, dict
|
||||
):
|
||||
# set api_version to version passed by user
|
||||
azure_client._custom_query.setdefault("api-version", api_version)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["messages"],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue