test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs

Krrish Dholakia 2023-12-11 11:38:28 -08:00
parent 1d2f5ce975
commit ea89a8a938
8 changed files with 501 additions and 122 deletions


@@ -801,9 +801,6 @@ class Logging:
end_time = datetime.datetime.now()
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
if isinstance(result, OpenAIObject):
result = result.model_dump()
if litellm.max_budget and self.stream:
time_diff = (end_time - start_time).total_seconds()
@@ -857,9 +854,6 @@ class Logging:
call_type = self.call_type,
stream = self.stream,
)
if callback == "api_manager":
print_verbose("reaches api manager for updating model cost")
litellm.apiManager.update_cost(completion_obj=result, user=self.user)
if callback == "promptlayer":
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
@@ -994,7 +988,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if isinstance(callback, CustomLogger): # custom logger class
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
print_verbose(f"success callbacks: Running Custom Logger Class")
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
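
The gate added in this hunk means a CustomLogger's synchronous hooks only fire for sync completion() calls; acompletion() traffic is routed to async_success_handler below. A minimal sketch of a conforming logger (method names mirror the call sites in this diff; the bodies are illustrative, not the committed code):

    from litellm.integrations.custom_logger import CustomLogger

    class MyLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            # sync path only, per the acompletion == False check above
            print(f"sync success for model={kwargs.get('model')}")

        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            # async path, invoked from async_success_handler
            print(f"async success for model={kwargs.get('model')}")
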
@@ -1044,7 +1038,6 @@ class Logging:
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response = None
if self.stream:
@@ -1081,6 +1074,13 @@ class Logging:
start_time=start_time,
end_time=end_time,
)
else:
await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time
)
else:
await callback.async_log_success_event(
kwargs=self.model_call_details,
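
The new else-branch gives async streaming a dedicated per-chunk hook instead of falling through to the sync log_stream_event. The receiving side, with the signature taken from the call site above (the body is illustrative):

    from litellm.integrations.custom_logger import CustomLogger

    class MyAsyncLogger(CustomLogger):
        async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
            # invoked once per chunk on the async streaming path
            print(f"chunk received at {end_time}")
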
@@ -1103,24 +1103,29 @@ class Logging:
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
if not hasattr(self, "model_call_details"):
self.model_call_details = {}
self.model_call_details["log_event_type"] = "failed_api_call"
self.model_call_details["exception"] = exception
self.model_call_details["traceback_exception"] = traceback_exception
self.model_call_details["end_time"] = end_time
self.model_call_details.setdefault("original_response", None)
return start_time, end_time
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
print_verbose(
f"Logging Details LiteLLM-Failure Call"
)
try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
if not hasattr(self, "model_call_details"):
self.model_call_details = {}
self.model_call_details["log_event_type"] = "failed_api_call"
self.model_call_details["exception"] = exception
self.model_call_details["traceback_exception"] = traceback_exception
self.model_call_details["end_time"] = end_time
start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
result = None # result sent to all loggers, init this to None incase it's not created
for callback in litellm.failure_callback:
try:
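
This refactor deduplicates the sync and async failure paths: both now call _failure_handler_helper_fn for the shared setup (default timestamps, guarding against an uninitialized model_call_details) before iterating their respective callback lists. Reduced to its skeleton, the resulting pattern looks roughly like this (illustrative, loop bodies elided):

    def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
        start_time, end_time = self._failure_handler_helper_fn(
            exception=exception,
            traceback_exception=traceback_exception,
            start_time=start_time,
            end_time=end_time,
        )
        result = None  # result sent to all loggers
        for callback in litellm.failure_callback:
            ...

    async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
        # identical setup via the shared helper, then the async callback list
        start_time, end_time = self._failure_handler_helper_fn(
            exception=exception,
            traceback_exception=traceback_exception,
            start_time=start_time,
            end_time=end_time,
        )
        result = None
        for callback in litellm._async_failure_callback:
            ...
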
@@ -1212,16 +1217,8 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
if not hasattr(self, "model_call_details"):
self.model_call_details = {}
self.model_call_details["log_event_type"] = "failed_api_call"
self.model_call_details["exception"] = exception
self.model_call_details["traceback_exception"] = traceback_exception
self.model_call_details["end_time"] = end_time
result = {} # result sent to all loggers, init this to None incase it's not created
start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
result = None # result sent to all loggers, init this to None incase it's not created
for callback in litellm._async_failure_callback:
try:
if isinstance(callback, CustomLogger): # custom logger class
@@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
return model_cost
def get_litellm_params(
return_async=False,
api_key=None,
force_timeout=600,
azure=False,
@@ -2082,7 +2078,6 @@ def get_litellm_params(
):
litellm_params = {
"acompletion": acompletion,
"return_async": return_async,
"api_key": api_key,
"force_timeout": force_timeout,
"logger_fn": logger_fn,
@@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = ""
self.complete_response = ""
if self.logging_obj:
# Log the type of the received item
self.logging_obj.post_call(str(type(completion_stream)))
def __iter__(self):
return self
@@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
except Exception as e:
raise e
def logging(self, text):
if self.logging_obj:
self.logging_obj.post_call(text)
def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
hold = False
if finish_reason:
@@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
print_verbose(f"model_response: {model_response}")
return model_response
else:
return
elif model_response.choices[0].finish_reason:
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
# enter this branch when no content has been passed in response
@@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
if self.sent_first_chunk == False:
model_response.choices[0].delta["role"] = "assistant"
self.sent_first_chunk = True
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
return model_response
else:
return
@@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
except Exception as e:
traceback_exception = traceback.format_exc()
e.message = str(e)
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
## needs to handle the empty string case (even starting chunk can be an empty string)
@@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b'':
response = self.chunk_creator(chunk=chunk)
if response is not None:
return response
if response is None:
continue
## LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
return response
except StopIteration:
raise # Re-raise StopIteration
except Exception as e:
# Handle other exceptions if needed
traceback_exception = traceback.format_exc()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
raise e
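
Taken together with the removals in chunk_creator above, success logging for sync streams is now consolidated in __next__: empty chunks and empty deltas are skipped via continue, and success_handler fires exactly once per yielded chunk, on a background thread so logging cannot block the stream. The resulting control flow, reduced to a sketch (the surrounding while True loop is implied by the continue statements):

    def __next__(self):
        try:
            while True:
                chunk = next(self.completion_stream)
                if chunk is None or chunk == b"":
                    continue  # skip empty frames from the transport
                response = self.chunk_creator(chunk=chunk)
                if response is None:
                    continue  # e.g. a held-back or empty delta
                # single logging call site for the sync streaming path
                threading.Thread(
                    target=self.logging_obj.success_handler, args=(response,)
                ).start()
                return response
        except StopIteration:
            raise  # end of stream
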
@@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
return processed_chunk
except Exception as e:
traceback_exception = traceback.format_exc()
# Handle any exceptions that might occur during streaming
asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
raise StopAsyncIteration
class TextCompletionStreamWrapper:
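
The test file named in the commit title exercises exactly this sync/async split. A minimal test in that spirit might look like the following; the TrackingLogger name, the mock_response kwarg, and the sleep-based settling are assumptions for illustration, not a copy of the committed tests:

    import asyncio
    import time

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class TrackingLogger(CustomLogger):
        def __init__(self):
            self.sync_success = False
            self.async_success = False

        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            self.sync_success = True

        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            self.async_success = True

    def test_acompletion_hits_async_hook_only():
        logger = TrackingLogger()
        litellm.callbacks = [logger]
        asyncio.run(
            litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "hi"}],
                mock_response="hello",  # assumed available; avoids a real API call
            )
        )
        time.sleep(1)  # success callbacks may run on a separate thread/task
        assert logger.async_success is True
        # the acompletion gate added in this commit keeps the sync hook quiet
        assert logger.sync_success is False
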