diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 7a7a38b79d..0881a7d6d4 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -118,28 +118,29 @@ class AzureChatCompletion(BaseLLM): ### CHECK IF CLOUDFLARE AI GATEWAY ### ### if so - set the model as part of the base url - if "gateway.ai.cloudflare.com" in api_base and client is None: + if "gateway.ai.cloudflare.com" in api_base: ## build base url - assume api base includes resource name - if not api_base.endswith("/"): - api_base += "/" - api_base += f"{model}" - - azure_client_params = { - "api_version": api_version, - "base_url": f"{api_base}", - "http_client": litellm.client_session, - "max_retries": max_retries, - "timeout": timeout - } - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - azure_client_params["azure_ad_token"] = azure_ad_token + if client is None: + if not api_base.endswith("/"): + api_base += "/" + api_base += f"{model}" + + azure_client_params = { + "api_version": api_version, + "base_url": f"{api_base}", + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + azure_client_params["azure_ad_token"] = azure_ad_token - if acompletion is True: - client = AsyncAzureOpenAI(**azure_client_params) - else: - client = AzureOpenAI(**azure_client_params) + if acompletion is True: + client = AsyncAzureOpenAI(**azure_client_params) + else: + client = AzureOpenAI(**azure_client_params) data = { "model": None, @@ -162,7 +163,7 @@ class AzureChatCompletion(BaseLLM): "azure_ad_token": azure_ad_token }, "api_version": api_version, - "api_base": api_base, + "api_base": client.base_url, "complete_input_dict": data, }, ) diff --git a/litellm/router.py b/litellm/router.py index 00b3a8d125..24f356e085 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -856,16 +856,32 @@ class Router: if "azure" in model_name: if api_version is None: api_version = "2023-07-01-preview" - model["async_client"] = openai.AsyncAzureOpenAI( - api_key=api_key, - azure_endpoint=api_base, - api_version=api_version - ) - model["client"] = openai.AzureOpenAI( - api_key=api_key, - azure_endpoint=api_base, - api_version=api_version - ) + if "gateway.ai.cloudflare.com" in api_base: + if not api_base.endswith("/"): + api_base += "/" + azure_model = model_name.replace("azure/", "") + api_base += f"{azure_model}" + model["async_client"] = openai.AsyncAzureOpenAI( + api_key=api_key, + base_url=api_base, + api_version=api_version + ) + model["client"] = openai.AzureOpenAI( + api_key=api_key, + base_url=api_base, + api_version=api_version + ) + else: + model["async_client"] = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=api_base, + api_version=api_version + ) + model["client"] = openai.AzureOpenAI( + api_key=api_key, + azure_endpoint=api_base, + api_version=api_version + ) else: model["async_client"] = openai.AsyncOpenAI( api_key=api_key, diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py index 05446f9f46..7a722095e6 100644 --- a/litellm/tests/test_async_fn.py +++ b/litellm/tests/test_async_fn.py @@ -65,7 +65,7 @@ def test_async_response_azure(): user_message = "What do you know?" messages = [{"content": user_message, "role": "user"}] try: - response = await acompletion(model="azure/chatgpt-v-2", messages=messages, timeout=5) + response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY")) print(f"response: {response}") except litellm.Timeout as e: pass @@ -76,6 +76,7 @@ def test_async_response_azure(): # test_async_response_azure() + def test_async_anyscale_response(): import asyncio litellm.set_verbose = True