From 4e89be0e195e38d30ef7078a35d29b06f7027adc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 22 Jan 2024 12:02:02 -0800 Subject: [PATCH] fix(azure_dall_e_2.py): handle azure not returning a 'retry-after' param --- litellm/llms/azure.py | 19 +++++++++++++++---- litellm/llms/custom_httpx/azure_dall_e_2.py | 7 ++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 0eb70c86f7..f20a2e9397 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -629,12 +629,23 @@ class AzureChatCompletion(BaseLLM): client_session = litellm.aclient_session or httpx.AsyncClient( transport=AsyncCustomHTTPTransport(), ) - openai_aclient = AsyncAzureOpenAI( + azure_client = AsyncAzureOpenAI( http_client=client_session, **azure_client_params ) else: - openai_aclient = client - response = await openai_aclient.images.generate(**data, timeout=timeout) + azure_client = client + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=azure_client.api_key, + additional_args={ + "headers": {"api_key": azure_client.api_key}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + response = await azure_client.images.generate(**data, timeout=timeout) stringified_response = response.model_dump() ## LOGGING logging_obj.post_call( @@ -719,7 +730,7 @@ class AzureChatCompletion(BaseLLM): input=prompt, api_key=azure_client.api_key, additional_args={ - "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "headers": {"api_key": azure_client.api_key}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data, diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index a62e1d666d..f361ede5bf 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -43,7 +43,7 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): request=request, ) - time.sleep(int(response.headers.get("retry-after")) or 10) + await asyncio.sleep(int(response.headers.get("retry-after") or 10)) response = await super().handle_async_request(request) await response.aread() @@ -95,7 +95,6 @@ class CustomHTTPTransport(httpx.HTTPTransport): request.method = "GET" response = super().handle_request(request) response.read() - timeout_secs: int = 120 start_time = time.time() while response.json()["status"] not in ["succeeded", "failed"]: @@ -112,11 +111,9 @@ class CustomHTTPTransport(httpx.HTTPTransport): content=json.dumps(timeout).encode("utf-8"), request=request, ) - - time.sleep(int(response.headers.get("retry-after")) or 10) + time.sleep(int(response.headers.get("retry-after", None) or 10)) response = super().handle_request(request) response.read() - if response.json()["status"] == "failed": error_data = response.json() return httpx.Response(