feat(utils.py): add async success callbacks for custom functions

Krrish Dholakia 2023-12-04 16:36:21 -08:00
parent b90fcbdac4
commit e0ccb281d8
8 changed files with 232 additions and 138 deletions
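
For context, a minimal usage sketch of what this change enables. The callback name and its parameter list here are illustrative assumptions; the diff below only shows that coroutine functions registered in litellm.success_callback get moved to litellm._async_success_callback and awaited via customLogger.async_log_event.

import litellm

# Hypothetical user integration. Because it is a coroutine function, the
# client() wrapper in this commit routes it to litellm._async_success_callback.
async def my_async_logger(kwargs, response_obj, start_time, end_time):
    latency = (end_time - start_time).total_seconds()
    print(f"model={kwargs.get('model')} latency={latency:.2f}s")

litellm.success_callback = [my_async_logger]

# litellm.acompletion(...) would now schedule my_async_logger on the running
# event loop instead of calling it from the sync logging thread.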

@@ -741,13 +741,9 @@ class Logging:
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
)
pass
def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
print_verbose(
f"Logging Details LiteLLM-Success Call"
)
try:
def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
try:
if start_time is None:
start_time = self.start_time
if end_time is None:
@@ -776,6 +772,18 @@ class Logging:
float_diff = float(time_diff)
litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)
return start_time, end_time, result, complete_streaming_response
except:
pass
def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
print_verbose(
f"Logging Details LiteLLM-Success Call"
)
try:
start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
print_verbose(f"success callbacks: {litellm.success_callback}")
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
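
Taken together, the two hunks above extract the time/result normalization into _success_handler_helper_fn so the existing sync handler and the new async handler can share it. A compressed sketch of that shape; the cost tracking and streaming-response assembly in the real helper are elided, and the datetime default is an assumption:

import datetime

class Logging:
    def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None):
        try:
            if start_time is None:
                start_time = self.start_time
            if end_time is None:
                end_time = datetime.datetime.now()  # assumed default
            complete_streaming_response = None  # assembled from chunks in the real code
            # ... cost tracking elided ...
            return start_time, end_time, result, complete_streaming_response
        except Exception:
            pass

    def success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
        start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(
            start_time=start_time, end_time=end_time, result=result
        )
        # ... then loop over litellm.success_callback (sync callbacks only) ...
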
@@ -969,6 +977,29 @@ class Logging:
)
pass
async def async_success_handler(self, result=None, start_time=None, end_time=None, **kwargs):
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
start_time, end_time, result, complete_streaming_response = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result)
print_verbose(f"success callbacks: {litellm.success_callback}")
for callback in litellm._async_success_callback:
try:
if callable(callback): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback
)
except:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
print_verbose(
f"Logging Details LiteLLM-Failure Call"
@@ -1185,6 +1216,17 @@ def client(original_function):
callback_list=callback_list,
function_id=function_id
)
## ASYNC CALLBACKS
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback):
if inspect.iscoroutinefunction(callback):
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.success_callback.pop(index)
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
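
That partitioning step is the core routing decision: coroutine functions are moved out of litellm.success_callback so the threaded sync handler never tries to call them, and popping by index in reverse keeps the remaining indices valid. The same idea in isolation; only the iscoroutinefunction test comes from the diff, everything else is a self-contained illustration:

import asyncio
import inspect

def split_callbacks(callbacks):
    """Separate plain functions from coroutine functions."""
    sync_cbs, async_cbs = [], []
    for cb in callbacks:
        (async_cbs if inspect.iscoroutinefunction(cb) else sync_cbs).append(cb)
    return sync_cbs, async_cbs

def log_sync(result):
    print("sync:", result)

async def log_async(result):
    await asyncio.sleep(0)
    print("async:", result)

sync_cbs, async_cbs = split_callbacks([log_sync, log_async])
assert sync_cbs == [log_sync] and async_cbs == [log_async]
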
@@ -1373,7 +1415,6 @@ def client(original_function):
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
# only set litellm_call_id if its not in kwargs
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
@@ -1426,8 +1467,8 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
# LOG SUCCESS - handle streaming success logging in the _next_ object
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# RETURN RESULT
if isinstance(result, ModelResponse):
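
With the lists split, the success path now dispatches twice: sync callbacks keep running on a thread, while async callbacks are scheduled as a task on the already-running event loop. A minimal sketch of that dual dispatch, with placeholder handlers standing in for the Logging methods:

import asyncio
import threading

def sync_success_handler(result, start_time=None, end_time=None):
    print("sync callbacks ran in a thread:", result)

async def async_success_handler(result, start_time=None, end_time=None):
    print("async callbacks ran on the event loop:", result)

async def acompletion_like(result="ok"):
    # fire-and-forget: schedule async logging on the running loop
    asyncio.create_task(async_success_handler(result))
    # keep sync logging off the event loop so it cannot block it
    threading.Thread(target=sync_success_handler, args=(result,)).start()
    await asyncio.sleep(0)  # yield once so the demo task gets a chance to run
    return result

asyncio.run(acompletion_like())
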
@@ -1465,7 +1506,6 @@ def client(original_function):
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
raise e
# Use httpx to determine if the original function is a coroutine
is_coroutine = inspect.iscoroutinefunction(original_function)
# Return the appropriate wrapper based on the original function type
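
For reference, the wrapper selection mentioned at the end of that hunk follows the usual decorator pattern of returning a sync or an async wrapper depending on the decorated function. A generic sketch, not the actual client() body, which is far larger:

import functools
import inspect

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        # ... sync pre-call setup, caching, success/failure logging ...
        return original_function(*args, **kwargs)

    @functools.wraps(original_function)
    async def wrapper_async(*args, **kwargs):
        # ... async variant; this is where asyncio.create_task(...) logging fits ...
        return await original_function(*args, **kwargs)

    # return the wrapper that matches the decorated function's type
    return wrapper_async if inspect.iscoroutinefunction(original_function) else wrapper
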
@@ -5370,6 +5410,8 @@ class CustomStreamWrapper:
processed_chunk = self.chunk_creator(chunk=chunk)
if processed_chunk is None:
continue
## LOGGING
asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
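
Finally, the last hunk wires the same fire-and-forget logging into the async streaming iterator, one task per delivered chunk. A self-contained sketch of that shape; the classes below are simplified stand-ins for CustomStreamWrapper and Logging, not the real implementations:

import asyncio

class MiniLogger:
    async def async_success_handler(self, chunk):
        print("logged chunk:", chunk)

class MiniStreamWrapper:
    def __init__(self, chunks, logging_obj):
        self._iter = iter(chunks)
        self.logging_obj = logging_obj

    def __aiter__(self):
        return self

    async def __anext__(self):
        for chunk in self._iter:
            if chunk is None:  # skipped chunks, like chunk_creator returning None
                continue
            # schedule logging without delaying delivery of the chunk
            asyncio.create_task(self.logging_obj.async_success_handler(chunk))
            return chunk
        raise StopAsyncIteration

async def main():
    async for piece in MiniStreamWrapper(["a", None, "b"], MiniLogger()):
        print("got:", piece)
    await asyncio.sleep(0)  # let pending logging tasks run in this demo

asyncio.run(main())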