From cb55069a70d74ec01aa8dfaa5375be86c882cd2f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 18 Mar 2025 10:11:54 -0700 Subject: [PATCH] _init_azure_client_for_cloudflare_ai_gateway --- litellm/llms/azure/azure.py | 43 +++++++++--------------------- litellm/llms/azure/common_utils.py | 42 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 31 deletions(-) diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 6613ce57ba..6a16b50c31 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -238,37 +238,18 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): ### CHECK IF CLOUDFLARE AI GATEWAY ### ### if so - set the model as part of the base url if "gateway.ai.cloudflare.com" in api_base: - ## build base url - assume api base includes resource name - if client is None: - if not api_base.endswith("/"): - api_base += "/" - api_base += f"{model}" - - azure_client_params = { - "api_version": api_version, - "base_url": f"{api_base}", - "http_client": litellm.client_session, - "max_retries": max_retries, - "timeout": timeout, - } - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc( - azure_ad_token - ) - - azure_client_params["azure_ad_token"] = azure_ad_token - elif azure_ad_token_provider is not None: - azure_client_params["azure_ad_token_provider"] = ( - azure_ad_token_provider - ) - - if acompletion is True: - client = AsyncAzureOpenAI(**azure_client_params) - else: - client = AzureOpenAI(**azure_client_params) + client = self._init_azure_client_for_cloudflare_ai_gateway( + api_base=api_base, + model=model, + api_version=api_version, + max_retries=max_retries, + timeout=timeout, + api_key=api_key, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + acompletion=acompletion, + client=client, + ) data = {"model": None, "messages": 
def _init_azure_client_for_cloudflare_ai_gateway(
    self,
    api_base: str,
    model: str,
    api_version: str,
    max_retries: int,
    timeout: Union[float, httpx.Timeout],
    api_key: Optional[str],
    azure_ad_token: Optional[str],
    azure_ad_token_provider: Optional[Callable[[], str]],
    acompletion: bool,
    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
    """Return an Azure OpenAI client for a Cloudflare AI Gateway base URL.

    If ``client`` is supplied it is returned untouched and every other
    argument is ignored. Otherwise a new client is built whose ``base_url``
    is ``api_base`` (with a trailing ``/`` added when missing) followed by
    ``model``.

    Exactly one credential is forwarded, chosen in this order:

    1. ``api_key``
    2. ``azure_ad_token`` (a value starting with ``"oidc/"`` is first
       exchanged via ``get_azure_ad_token_from_oidc``)
    3. ``azure_ad_token_provider``

    When all three are ``None`` no credential keyword is passed at all.

    Args:
        api_base: Gateway URL; the code assumes it already contains the
            Azure resource name.
        model: Appended to ``api_base`` as the last path segment.
        api_version: Forwarded as ``api_version``.
        max_retries: Forwarded as ``max_retries``.
        timeout: Forwarded as ``timeout``.
        api_key: Optional API key.
        azure_ad_token: Optional Azure AD token or ``"oidc/..."`` reference.
        azure_ad_token_provider: Optional callable forwarded as-is.
        acompletion: ``True`` selects ``AsyncAzureOpenAI``; anything else
            selects ``AzureOpenAI``.
        client: Pre-built client to reuse instead of creating one.

    Returns:
        The supplied ``client`` or the newly constructed one.
    """
    # Reuse a caller-provided client as-is.
    if client is not None:
        return client

    ## build base url - assume api base includes resource name
    url_prefix = api_base if api_base.endswith("/") else api_base + "/"

    # NOTE(review): the same `litellm.client_session` is handed to both the
    # sync and the async client below - confirm it is a valid `http_client`
    # for AsyncAzureOpenAI.
    client_kwargs = {
        "api_version": api_version,
        "base_url": f"{url_prefix}{model}",
        "http_client": litellm.client_session,
        "max_retries": max_retries,
        "timeout": timeout,
    }

    # First available credential wins; later ones are ignored.
    if api_key is not None:
        client_kwargs["api_key"] = api_key
    elif azure_ad_token is not None:
        resolved_token = azure_ad_token
        if resolved_token.startswith("oidc/"):
            resolved_token = get_azure_ad_token_from_oidc(resolved_token)
        client_kwargs["azure_ad_token"] = resolved_token
    elif azure_ad_token_provider is not None:
        client_kwargs["azure_ad_token_provider"] = azure_ad_token_provider

    client_cls = AsyncAzureOpenAI if acompletion is True else AzureOpenAI
    return client_cls(**client_kwargs)