fix(azure_dall_e_2.py): handle azure not returning a 'retry-after' param

This commit is contained in:
Krrish Dholakia 2024-01-22 12:02:02 -08:00
parent 11bd726942
commit 4e89be0e19
2 changed files with 17 additions and 9 deletions

View file

@ -629,12 +629,23 @@ class AzureChatCompletion(BaseLLM):
client_session = litellm.aclient_session or httpx.AsyncClient( client_session = litellm.aclient_session or httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(), transport=AsyncCustomHTTPTransport(),
) )
openai_aclient = AsyncAzureOpenAI( azure_client = AsyncAzureOpenAI(
http_client=client_session, **azure_client_params http_client=client_session, **azure_client_params
) )
else: else:
openai_aclient = client azure_client = client
response = await openai_aclient.images.generate(**data, timeout=timeout) ## LOGGING
logging_obj.pre_call(
input=data["prompt"],
api_key=azure_client.api_key,
additional_args={
"headers": {"api_key": azure_client.api_key},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.images.generate(**data, timeout=timeout)
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -719,7 +730,7 @@ class AzureChatCompletion(BaseLLM):
input=prompt, input=prompt,
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={ additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "headers": {"api_key": azure_client.api_key},
"api_base": azure_client._base_url._uri_reference, "api_base": azure_client._base_url._uri_reference,
"acompletion": False, "acompletion": False,
"complete_input_dict": data, "complete_input_dict": data,

View file

@ -43,7 +43,7 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
request=request, request=request,
) )
time.sleep(int(response.headers.get("retry-after")) or 10) await asyncio.sleep(int(response.headers.get("retry-after") or 10))
response = await super().handle_async_request(request) response = await super().handle_async_request(request)
await response.aread() await response.aread()
@ -95,7 +95,6 @@ class CustomHTTPTransport(httpx.HTTPTransport):
request.method = "GET" request.method = "GET"
response = super().handle_request(request) response = super().handle_request(request)
response.read() response.read()
timeout_secs: int = 120 timeout_secs: int = 120
start_time = time.time() start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]: while response.json()["status"] not in ["succeeded", "failed"]:
@ -112,11 +111,9 @@ class CustomHTTPTransport(httpx.HTTPTransport):
content=json.dumps(timeout).encode("utf-8"), content=json.dumps(timeout).encode("utf-8"),
request=request, request=request,
) )
time.sleep(int(response.headers.get("retry-after", None) or 10))
time.sleep(int(response.headers.get("retry-after")) or 10)
response = super().handle_request(request) response = super().handle_request(request)
response.read() response.read()
if response.json()["status"] == "failed": if response.json()["status"] == "failed":
error_data = response.json() error_data = response.json()
return httpx.Response( return httpx.Response(