fix(acompletion): support client side timeouts + raise exceptions correctly for async calls

Krrish Dholakia 2023-11-17 15:39:39 -08:00
parent c4e53aa77b
commit 0ab6b2451d
8 changed files with 142 additions and 81 deletions
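For context, a minimal sketch of the behavior this commit targets, assuming litellm's public acompletion API (the model name and the 10-second value are illustrative):

import asyncio
import litellm

async def main():
    try:
        # Client-side timeout: give up if the provider does not
        # respond within 10 seconds.
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello"}],
            timeout=10,
        )
        print(response.choices[0].message.content)
    except litellm.exceptions.Timeout:
        # With this fix, async timeouts surface as litellm's Timeout
        # exception instead of leaking raw client errors.
        print("request timed out")

asyncio.run(main())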

litellm/utils.py

@@ -18,7 +18,7 @@ import tiktoken
 import uuid
 import aiohttp
 import logging
-import asyncio, httpx
+import asyncio, httpx, inspect
 import copy
 from tokenizers import Tokenizer
 from dataclasses import (
@@ -1047,7 +1047,6 @@ def exception_logging(
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys
-    import inspect
     def function_setup(
         start_time, *args, **kwargs
     ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
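The two hunks above move `import inspect` from the body of the `client` decorator to module scope, so the import runs once at load time rather than every time a function is decorated. Schematically (simplified, not litellm's actual wrapper):

import inspect  # resolved once, at module import time

def client(original_function):
    # Previously `import inspect` ran here, i.e. on each decoration;
    # hoisting it avoids the repeated sys.modules lookup and keeps the
    # decorator body focused on its own logic.
    def wrapper(*args, **kwargs):
        return original_function(*args, **kwargs)
    return wrapper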
@@ -1333,13 +1332,13 @@ def client(original_function):
                         kwargs["retry_strategy"] = "exponential_backoff_retry"
                     elif (isinstance(e, openai.APIError)): # generic api error
                         kwargs["retry_strategy"] = "constant_retry"
-                    return litellm.completion_with_retries(*args, **kwargs)
+                    return await litellm.acompletion_with_retries(*args, **kwargs)
                 elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
                     if len(args) > 0:
                         args[0] = context_window_fallback_dict[model]
                     else:
                         kwargs["model"] = context_window_fallback_dict[model]
-                    return original_function(*args, **kwargs)
+                    return await original_function(*args, **kwargs)
             traceback_exception = traceback.format_exc()
             crash_reporting(*args, **kwargs, exception=traceback_exception)
             end_time = datetime.datetime.now()
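The substantive fix in the hunk above is the added `await`: `acompletion_with_retries` and the async `original_function` are coroutine functions, so calling them without `await` hands the caller an unexecuted coroutine object, and any exception they would raise is silently dropped. A self-contained sketch of the failure mode (names here are illustrative, not from litellm):

import asyncio

async def flaky_call():
    raise RuntimeError("boom")

async def buggy_wrapper():
    # Bug pattern: returns a coroutine object; flaky_call never runs,
    # so its RuntimeError is silently lost.
    return flaky_call()

async def fixed_wrapper():
    # Fixed pattern: awaiting executes the call, so the exception
    # propagates to the caller as expected.
    return await flaky_call()

async def main():
    result = await buggy_wrapper()
    print(type(result))  # <class 'coroutine'> -- not a real result
    result.close()  # silence the "never awaited" RuntimeWarning

    try:
        await fixed_wrapper()
    except RuntimeError as e:
        print(f"exception propagated: {e}")

asyncio.run(main())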
@@ -3370,7 +3369,7 @@ def exception_type(
             else:
                 exception_type = ""
-            if "Request Timeout Error" in error_str:
+            if "Request Timeout Error" in error_str or "Request timed out" in error_str:
                 exception_mapping_worked = True
                 raise Timeout(
                     message=f"APITimeoutError - Request timed out",
@@ -3411,7 +3410,6 @@ def exception_type(
                         message=f"OpenAIException - {original_exception.message}",
                         model=model,
                         llm_provider="openai",
-                        request=original_exception.request
                     )
                 if original_exception.status_code == 422:
                     exception_mapping_worked = True
@@ -4229,7 +4227,7 @@ def exception_type(
                 llm_provider=custom_llm_provider,
                 response=original_exception.response
             )
-        else: # ensure generic errors always return APIConnectionError
+        else: # ensure generic errors always return APIConnectionError=
             exception_mapping_worked = True
             if hasattr(original_exception, "request"):
                 raise APIConnectionError(
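The final hunk's fallback means anything not matched by a specific mapping still leaves exception_type as a litellm exception, so async callers can rely on a closed set of error types. A consumer-side sketch of the pattern this enables (the wrapper name is illustrative):

import litellm
from litellm.exceptions import APIConnectionError, Timeout

async def safe_completion(**kwargs):
    try:
        return await litellm.acompletion(**kwargs)
    except Timeout:
        # Client-side timeout, including the new "Request timed out"
        # string match above.
        return None
    except APIConnectionError as e:
        # The generic fallback: otherwise-unmapped errors are re-raised
        # as APIConnectionError, so this is a reliable catch-all.
        print(f"connection-level failure: {e}")
        return None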