diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index f6c8ae9a4..e75886636 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -84,7 +84,7 @@ class AzureChatCompletion(BaseLLM): } if api_key is not None: headers["api-key"] = api_key - if azure_ad_token is not None: + elif azure_ad_token is not None: headers["Authorization"] = f"Bearer {azure_ad_token}" return headers @@ -139,7 +139,20 @@ class AzureChatCompletion(BaseLLM): max_retries = data.pop("max_retries", 2) if not isinstance(max_retries, int): raise AzureOpenAIError(status_code=422, message="max retries must be an int") - azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token + azure_client = AzureOpenAI(**azure_client_params) response = azure_client.chat.completions.create(**data) # type: ignore return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) except AzureOpenAIError as e: @@ -162,7 +175,20 @@ class AzureChatCompletion(BaseLLM): max_retries = data.pop("max_retries", 2) if not isinstance(max_retries, int): raise AzureOpenAIError(status_code=422, message="max retries must be an int") - azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token + azure_client = AsyncAzureOpenAI(**azure_client_params) response = await azure_client.chat.completions.create(**data) return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) except Exception as e: @@ -186,7 +212,20 @@ class AzureChatCompletion(BaseLLM): max_retries = data.pop("max_retries", 2) if not isinstance(max_retries, int): raise AzureOpenAIError(status_code=422, message="max retries must be an int") - azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token + azure_client = AzureOpenAI(**azure_client_params) response = azure_client.chat.completions.create(**data) streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) for transformed_chunk in streamwrapper: @@ -201,7 +240,20 @@ class AzureChatCompletion(BaseLLM): model: str, timeout: Any, azure_ad_token: Optional[str]=None): - azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2)) + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": data.pop("max_retries", 2), + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token + azure_client = AsyncAzureOpenAI(**azure_client_params) response = await azure_client.chat.completions.create(**data) streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) async for transformed_chunk in streamwrapper: @@ -213,6 +265,7 @@ class AzureChatCompletion(BaseLLM): api_key: str, api_base: str, api_version: str, + timeout: float, logging_obj=None, model_response=None, optional_params=None, @@ -230,8 +283,21 @@ class AzureChatCompletion(BaseLLM): max_retries = data.pop("max_retries", 2) if not isinstance(max_retries, int): raise AzureOpenAIError(status_code=422, message="max retries must be an int") - azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, max_retries=max_retries) - + + # init AzureOpenAI Client + azure_client_params = { + "api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token + azure_client = AzureOpenAI(**azure_client_params) ## LOGGING logging_obj.pre_call( input=input,