feat(utils.py): support dynamic success callbacks

This commit is contained in:
Krrish Dholakia 2024-02-01 19:42:01 -08:00
parent 1a1b929a4e
commit b79a6607b2

View file

@ -750,6 +750,8 @@ class Logging:
start_time,
litellm_call_id,
function_id,
dynamic_success_callbacks=None,
dynamic_async_success_callbacks=None,
):
if call_type not in [item.value for item in CallTypes]:
allowed_values = ", ".join([item.value for item in CallTypes])
@ -770,6 +772,14 @@ class Logging:
self.streaming_chunks = [] # for generating complete stream response
self.sync_streaming_chunks = [] # for generating complete stream response
self.model_call_details = {}
self.dynamic_input_callbacks = [] # callbacks set for just that call
self.dynamic_failure_callbacks = [] # callbacks set for just that call
self.dynamic_success_callbacks = (
dynamic_success_callbacks or []
) # callbacks set for just that call
self.dynamic_async_success_callbacks = (
dynamic_async_success_callbacks or []
) # callbacks set for just that call
def update_environment_variables(
self, model, user, optional_params, litellm_params, **additional_params
@ -873,7 +883,8 @@ class Logging:
)
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
for callback in litellm.input_callback:
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
try:
if callback == "supabase":
print_verbose("reaches supabase for logging!")
@ -946,43 +957,6 @@ class Logging:
if capture_exception: # log this error to sentry for debugging
capture_exception(e)
async def async_pre_call(
self, result=None, start_time=None, end_time=None, **kwargs
):
"""
 Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result
)
print_verbose(f"Async input callbacks: {litellm._async_input_callback}")
for callback in litellm._async_input_callback:
try:
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(f"Async input callbacks: CustomLogger")
asyncio.create_task(
callback.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
)
)
if callable(callback): # custom logger functions
print_verbose(f"Async success callbacks: async_log_event")
asyncio.create_task(
customLogger.async_log_input_event(
model=self.model,
messages=self.messages,
kwargs=self.model_call_details,
print_verbose=print_verbose,
callback_func=callback,
)
)
except:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def post_call(
self, original_response, input=None, api_key=None, additional_args={}
):
@ -1015,7 +989,9 @@ class Logging:
)
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
for callback in litellm.input_callback:
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
try:
if callback == "lite_debugger":
print_verbose("reaches litedebugger for post-call logging!")
@ -1164,8 +1140,8 @@ class Logging:
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
for callback in litellm.success_callback:
callbacks = litellm.success_callback + self.dynamic_success_callbacks
for callback in callbacks:
try:
if callback == "lite_debugger":
print_verbose("reaches lite_debugger for logging!")
@ -1466,7 +1442,10 @@ class Logging:
)
self.model_call_details["response_cost"] = None
for callback in litellm._async_success_callback:
callbacks = (
litellm._async_success_callback + self.dynamic_async_success_callbacks
)
for callback in callbacks:
try:
if callback == "cache" and litellm.cache is not None:
# set_cache once complete streaming response is built
@ -1968,6 +1947,26 @@ def client(original_function):
# Pop the async items from failure_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.failure_callback.pop(index)
### DYNAMIC CALLBACKS ###
dynamic_success_callbacks = []
dynamic_async_success_callbacks = []
if kwargs.get("success_callback", None) is not None and isinstance(
kwargs["success_callback"], list
):
removed_async_items = []
for index, callback in enumerate(kwargs["success_callback"]):
if (
inspect.iscoroutinefunction(callback)
or callback == "dynamodb"
or callback == "s3"
):
dynamic_async_success_callbacks.append(callback)
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
kwargs["success_callback"].pop(index)
dynamic_success_callbacks = kwargs["success_callback"]
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
@ -2029,6 +2028,8 @@ def client(original_function):
function_id=function_id,
call_type=call_type,
start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks,
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
)
## check if metadata is passed in
litellm_params = {}