(feat) litellm - add _async_failure_callback

This commit is contained in:
ishaan-jaff 2023-12-06 14:41:40 -08:00
parent 3b17fd3821
commit b3f039627e
3 changed files with 63 additions and 0 deletions

View file

@ -10,6 +10,7 @@ success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = [] callbacks: List[Callable] = []
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. _async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = [] pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = [] post_call_rules: List[Callable] = []
email: Optional[ email: Optional[

View file

@ -81,3 +81,23 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
# traceback.print_exc() # traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
# Method definition
try:
kwargs["log_event_type"] = "post_api_call"
await callback_func(
kwargs, # kwargs to func
response_obj,
start_time,
end_time,
)
print_verbose(
f"Custom Logger - final response object: {response_obj}"
)
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass

View file

@ -1113,6 +1113,36 @@ class Logging:
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}" f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {traceback.format_exc()}"
) )
pass pass
async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
if not hasattr(self, "model_call_details"):
self.model_call_details = {}
self.model_call_details["log_event_type"] = "failed_api_call"
self.model_call_details["exception"] = exception
self.model_call_details["traceback_exception"] = traceback_exception
self.model_call_details["end_time"] = end_time
result = {} # result sent to all loggers, init this to None incase it's not created
for callback in litellm._async_failure_callback:
try:
if callable(callback): # custom logger functions
await customLogger.async_log_failure_event(
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback
)
except:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def exception_logging( def exception_logging(
@ -1236,6 +1266,17 @@ def client(original_function):
# Pop the async items from success_callback in reverse order to avoid index issues # Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items): for index in reversed(removed_async_items):
litellm.success_callback.pop(index) litellm.success_callback.pop(index)
if len(litellm.failure_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.failure_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_failure_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.success_callback.pop(index)
if add_breadcrumb: if add_breadcrumb:
add_breadcrumb( add_breadcrumb(
category="litellm.llm_call", category="litellm.llm_call",
@ -1513,6 +1554,7 @@ def client(original_function):
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
if logging_obj: if logging_obj:
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
raise e raise e
is_coroutine = inspect.iscoroutinefunction(original_function) is_coroutine = inspect.iscoroutinefunction(original_function)