fix(main.py): keep client consistent across calls + exponential backoff retry on ratelimit errors

Author: Krrish Dholakia
Date:   2023-11-14 16:25:36 -08:00
Parent: 5963d9d283
Commit: a7222f257c
9 changed files with 239 additions and 131 deletions
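Note on the "exponential backoff retry" in the commit title: the new async wrapper below delegates the actual retrying to litellm.completion_with_retries, after setting kwargs["retry_strategy"] to "exponential_backoff_retry" for rate-limit errors and "constant_retry" for generic API errors. Purely as a generic illustration of that backoff strategy (the helper name and parameters here are illustrative, not part of litellm's API):

import random
import time


def retry_with_exponential_backoff(fn, retryable_exceptions, num_retries=3, base_delay=1.0):
    # Illustrative helper: call fn(); on a retryable error, sleep base_delay * 2**attempt
    # (plus a little jitter) and try again, re-raising once retries are exhausted.
    for attempt in range(num_retries + 1):
        try:
            return fn()
        except retryable_exceptions:
            if attempt == num_retries:
                raise  # out of retries, surface the last error
            time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.25))


# hypothetical usage: retry_with_exponential_backoff(lambda: call_provider(), (SomeRateLimitError,))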


@@ -456,6 +456,7 @@ from enum import Enum
class CallTypes(Enum):
embedding = 'embedding'
completion = 'completion'
acompletion = 'acompletion'
# Logging function -> log the exact model details + what's being sent | Non-Blocking
class Logging:
@@ -984,7 +985,7 @@ def exception_logging(
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
global liteDebuggerClient, get_all_keys
import inspect
def function_setup(
start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1036,7 +1037,7 @@ def client(original_function):
# INIT LOGGER - for user-specified integrations
model = args[0] if len(args) > 0 else kwargs["model"]
call_type = original_function.__name__
if call_type == CallTypes.completion.value:
if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
if len(args) > 1:
messages = args[1]
elif kwargs.get("messages", None):
@@ -1183,7 +1184,107 @@ def client(original_function):
): # make it easy to get to the debugger logs if you've initialized it
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e
return wrapper
async def wrapper_async(*args, **kwargs):
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
# only set litellm_call_id if its not in kwargs
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
try:
model = args[0] if len(args) > 0 else kwargs["model"]
except:
raise ValueError("model param not passed in.")
try:
if logging_obj is None:
logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
if litellm._current_cost > litellm.max_budget:
raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
# [OPTIONAL] CHECK CACHE
print_verbose(f"litellm.cache: {litellm.cache}")
print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
# if caching is false, don't run this
if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
# checking cache
if (litellm.cache != None):
print_verbose(f"Checking Cache")
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result != None:
print_verbose(f"Cache Hit!")
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else:
return cached_result
# MODEL CALL
result = original_function(*args, **kwargs)
end_time = datetime.datetime.now()
if "stream" in kwargs and kwargs["stream"] == True:
if "complete_response" in kwargs and kwargs["complete_response"] == True:
chunks = []
for idx, chunk in enumerate(result):
chunks.append(chunk)
return litellm.stream_chunk_builder(chunks)
else:
return result
result = await result
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
logging_obj.success_handler(result, start_time, end_time)
# RETURN RESULT
return result
except Exception as e:
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value:
num_retries = (
kwargs.get("num_retries", None)
or litellm.num_retries
or None
)
litellm.num_retries = None # set retries to None to prevent infinite loops
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})
if num_retries:
kwargs["num_retries"] = num_retries
kwargs["original_function"] = original_function
if (isinstance(e, openai.RateLimitError)): # rate limiting specific error
kwargs["retry_strategy"] = "exponential_backoff_retry"
elif (isinstance(e, openai.APIError)): # generic api error
kwargs["retry_strategy"] = "constant_retry"
return litellm.completion_with_retries(*args, **kwargs)
elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
else:
kwargs["model"] = context_window_fallback_dict[model]
return original_function(*args, **kwargs)
traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
if logging_obj:
threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
raise e
# Use inspect to determine if the original function is a coroutine
is_coroutine = inspect.iscoroutinefunction(original_function)
# Return the appropriate wrapper based on the original function type
if is_coroutine:
return wrapper_async
else:
return wrapper
####### USAGE CALCULATOR ################
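The tail of the hunk above makes the decorator coroutine-aware: it returns wrapper_async when the wrapped function is a coroutine function (e.g. acompletion) and the existing wrapper otherwise, with the choice made by inspect.iscoroutinefunction. A stripped-down sketch of that dispatch pattern, with the logging, caching, and retry machinery omitted (fake_acompletion is a hypothetical stand-in for a wrapped function):

import asyncio
import functools
import inspect


def client(original_function):
    # Return an async wrapper for coroutine functions, a sync wrapper otherwise.

    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        # setup / cache check / success logging would go here
        return original_function(*args, **kwargs)

    @functools.wraps(original_function)
    async def wrapper_async(*args, **kwargs):
        # same setup, but the underlying call is awaited
        return await original_function(*args, **kwargs)

    return wrapper_async if inspect.iscoroutinefunction(original_function) else wrapper


@client
async def fake_acompletion(model, messages):
    await asyncio.sleep(0)  # stand-in for a real provider call
    return {"model": model, "messages": messages}


print(asyncio.run(fake_acompletion("gpt-3.5-turbo", [{"role": "user", "content": "hi"}])))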
@@ -3116,31 +3217,13 @@ def exception_type(
print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa
print() # noqa
try:
if isinstance(original_exception, OriginalError):
# Handle the OpenAIError
exception_mapping_worked = True
if custom_llm_provider == "openrouter":
if original_exception.http_status == 413:
raise BadRequestError(
message=str(original_exception),
model=model,
llm_provider="openrouter"
)
original_exception.llm_provider = "openrouter"
if "This model's maximum context length is" in original_exception._message:
raise ContextWindowExceededError(
message=str(original_exception),
model=model,
llm_provider=original_exception.llm_provider
)
raise original_exception
elif model:
if model:
error_str = str(original_exception)
if isinstance(original_exception, BaseException):
exception_type = type(original_exception).__name__
else:
exception_type = ""
if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai":
if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
if "This model's maximum context length is" in error_str or "Request too large" in error_str:
exception_mapping_worked = True
raise ContextWindowExceededError(
@@ -3191,6 +3274,14 @@ def exception_type(
llm_provider="openai",
response=original_exception.response
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"OpenAIException - {original_exception.message}",
model=model,
llm_provider="openai",
response=original_exception.response
)
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
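The 503 branch added in this hunk extends the status-code-to-typed-exception mapping for OpenAI errors. A reduced sketch of the pattern (the stand-in exception classes here omit the model, llm_provider, and response arguments the real litellm classes take):

class ServiceUnavailableError(Exception):
    pass  # stand-in for litellm's ServiceUnavailableError


class Timeout(Exception):
    pass  # stand-in for litellm's Timeout error


def map_openai_status(status_code, message):
    # Translate an HTTP status code from the provider into a typed exception.
    if status_code == 503:
        raise ServiceUnavailableError(f"OpenAIException - {message}")
    if status_code == 504:  # gateway timeout
        raise Timeout(f"OpenAIException - {message}")
    raise RuntimeError(f"OpenAIException - {message}")  # fallback; roughly litellm's APIError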
@@ -3968,49 +4059,6 @@ def exception_type(
model=model,
request=original_exception.request
)
elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
if hasattr(original_exception, "status_code"):
exception_mapping_worked = True
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"CustomOpenAIException - {original_exception.message}",
llm_provider="custom_openai",
model=model
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"CustomOpenAIException - {original_exception.message}",
model=model,
llm_provider="custom_openai",
request=original_exception.request
)
if original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"CustomOpenAIException - {original_exception.message}",
model=model,
llm_provider="custom_openai",
response=original_exception.response
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"CustomOpenAIException - {original_exception.message}",
model=model,
llm_provider="custom_openai",
response=original_exception.response
)
else:
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"CustomOpenAIException - {original_exception.message}",
llm_provider="custom_openai",
model=model,
request=original_exception.request
)
if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
exception_mapping_worked = True
raise BadRequestError(