diff --git a/litellm/llms/azure/completion/handler.py b/litellm/llms/azure/completion/handler.py index 4ec5c435da..91a00ebc2f 100644 --- a/litellm/llms/azure/completion/handler.py +++ b/litellm/llms/azure/completion/handler.py @@ -76,33 +76,25 @@ class AzureTextCompletion(BaseAzureLLM): model_name=model, api_version=api_version, api_base=api_base, + is_async=False, ) ### CHECK IF CLOUDFLARE AI GATEWAY ### ### if so - set the model as part of the base url if "gateway.ai.cloudflare.com" in api_base: ## build base url - assume api base includes resource name - if client is None: - if not api_base.endswith("/"): - api_base += "/" - api_base += f"{model}" - - azure_client_params = { - "api_version": api_version, - "base_url": f"{api_base}", - "http_client": litellm.client_session, - "max_retries": max_retries, - "timeout": timeout, - } - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - azure_client_params["azure_ad_token"] = azure_ad_token - - if acompletion is True: - client = AsyncAzureOpenAI(**azure_client_params) - else: - client = AzureOpenAI(**azure_client_params) + client = self._init_azure_client_for_cloudflare_ai_gateway( + api_key=api_key, + api_version=api_version, + api_base=api_base, + model=model, + client=client, + max_retries=max_retries, + timeout=timeout, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + acompletion=acompletion, + ) data = {"model": None, "prompt": prompt, **optional_params} else: @@ -174,17 +166,21 @@ class AzureTextCompletion(BaseAzureLLM): status_code=422, message="max retries must be an int" ) # init AzureOpenAI Client - if client is None: - azure_client = AzureOpenAI(**azure_client_params) - else: - azure_client = client - if api_version is not None and isinstance( - azure_client._custom_query, dict - ): - # set api_version to version passed by user - azure_client._custom_query.setdefault( - "api-version", api_version - ) + azure_client = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + litellm_params=litellm_params, + _is_async=False, + model=model, + ) + + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) raw_response = azure_client.completions.with_raw_response.create( **data, timeout=timeout @@ -234,20 +230,27 @@ class AzureTextCompletion(BaseAzureLLM): azure_ad_token: Optional[str] = None, client=None, # this is the AsyncAzureOpenAI azure_client_params: dict = {}, + litellm_params: dict = {}, ): response = None try: # init AzureOpenAI Client # setting Azure client - if client is None: - azure_client = AsyncAzureOpenAI(**azure_client_params) - else: - azure_client = client - if api_version is not None and isinstance( - azure_client._custom_query, dict - ): - # set api_version to version passed by user - azure_client._custom_query.setdefault("api-version", api_version) + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AsyncAzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AsyncAzureOpenAI", + ) + ## LOGGING logging_obj.pre_call( input=data["prompt"], @@ -291,6 +294,7 @@ class AzureTextCompletion(BaseAzureLLM): azure_ad_token: Optional[str] = None, client=None, azure_client_params: dict = {}, + litellm_params: dict = {}, ): max_retries = data.pop("max_retries", 2) if not isinstance(max_retries, int): @@ -298,13 +302,21 @@ class AzureTextCompletion(BaseAzureLLM): status_code=422, message="max retries must be an int" ) # init AzureOpenAI Client - if client is None: - azure_client = AzureOpenAI(**azure_client_params) - else: - azure_client = client - if api_version is not None and isinstance(azure_client._custom_query, dict): - # set api_version to version passed by user - azure_client._custom_query.setdefault("api-version", api_version) + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=False, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + ## LOGGING logging_obj.pre_call( input=data["prompt"], @@ -340,18 +352,24 @@ class AzureTextCompletion(BaseAzureLLM): azure_ad_token: Optional[str] = None, client=None, azure_client_params: dict = {}, + litellm_params: dict = {}, ): try: # init AzureOpenAI Client - if client is None: - azure_client = AsyncAzureOpenAI(**azure_client_params) - else: - azure_client = client - if api_version is not None and isinstance( - azure_client._custom_query, dict - ): - # set api_version to version passed by user - azure_client._custom_query.setdefault("api-version", api_version) + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AsyncAzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AsyncAzureOpenAI", + ) ## LOGGING logging_obj.pre_call( input=data["prompt"],