forked from phoenix/litellm-mirror
test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs
This commit is contained in:
parent
1d2f5ce975
commit
ea89a8a938
8 changed files with 501 additions and 122 deletions
|
@ -801,9 +801,6 @@ class Logging:
|
|||
end_time = datetime.datetime.now()
|
||||
self.model_call_details["log_event_type"] = "successful_api_call"
|
||||
self.model_call_details["end_time"] = end_time
|
||||
|
||||
if isinstance(result, OpenAIObject):
|
||||
result = result.model_dump()
|
||||
|
||||
if litellm.max_budget and self.stream:
|
||||
time_diff = (end_time - start_time).total_seconds()
|
||||
|
@ -857,9 +854,6 @@ class Logging:
|
|||
call_type = self.call_type,
|
||||
stream = self.stream,
|
||||
)
|
||||
if callback == "api_manager":
|
||||
print_verbose("reaches api manager for updating model cost")
|
||||
litellm.apiManager.update_cost(completion_obj=result, user=self.user)
|
||||
if callback == "promptlayer":
|
||||
print_verbose("reaches promptlayer for logging!")
|
||||
promptLayerLogger.log_event(
|
||||
|
@ -994,7 +988,7 @@ class Logging:
|
|||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if isinstance(callback, CustomLogger): # custom logger class
|
||||
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
|
||||
print_verbose(f"success callbacks: Running Custom Logger Class")
|
||||
if self.stream and complete_streaming_response is None:
|
||||
callback.log_stream_event(
|
||||
|
@ -1044,7 +1038,6 @@ class Logging:
|
|||
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
|
||||
"""
|
||||
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
|
||||
|
||||
## BUILD COMPLETE STREAMED RESPONSE
|
||||
complete_streaming_response = None
|
||||
if self.stream:
|
||||
|
@ -1081,6 +1074,13 @@ class Logging:
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
else:
|
||||
await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
|
||||
kwargs=self.model_call_details,
|
||||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
else:
|
||||
await callback.async_log_success_event(
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -1103,24 +1103,29 @@ class Logging:
|
|||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
|
||||
if start_time is None:
|
||||
start_time = self.start_time
|
||||
if end_time is None:
|
||||
end_time = datetime.datetime.now()
|
||||
|
||||
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
|
||||
if not hasattr(self, "model_call_details"):
|
||||
self.model_call_details = {}
|
||||
|
||||
self.model_call_details["log_event_type"] = "failed_api_call"
|
||||
self.model_call_details["exception"] = exception
|
||||
self.model_call_details["traceback_exception"] = traceback_exception
|
||||
self.model_call_details["end_time"] = end_time
|
||||
self.model_call_details.setdefault("original_response", None)
|
||||
return start_time, end_time
|
||||
|
||||
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
|
||||
print_verbose(
|
||||
f"Logging Details LiteLLM-Failure Call"
|
||||
)
|
||||
try:
|
||||
if start_time is None:
|
||||
start_time = self.start_time
|
||||
if end_time is None:
|
||||
end_time = datetime.datetime.now()
|
||||
|
||||
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
|
||||
if not hasattr(self, "model_call_details"):
|
||||
self.model_call_details = {}
|
||||
|
||||
self.model_call_details["log_event_type"] = "failed_api_call"
|
||||
self.model_call_details["exception"] = exception
|
||||
self.model_call_details["traceback_exception"] = traceback_exception
|
||||
self.model_call_details["end_time"] = end_time
|
||||
start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
|
||||
result = None # result sent to all loggers, init this to None incase it's not created
|
||||
for callback in litellm.failure_callback:
|
||||
try:
|
||||
|
@ -1212,16 +1217,8 @@ class Logging:
|
|||
"""
|
||||
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
|
||||
"""
|
||||
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
|
||||
if not hasattr(self, "model_call_details"):
|
||||
self.model_call_details = {}
|
||||
|
||||
self.model_call_details["log_event_type"] = "failed_api_call"
|
||||
self.model_call_details["exception"] = exception
|
||||
self.model_call_details["traceback_exception"] = traceback_exception
|
||||
self.model_call_details["end_time"] = end_time
|
||||
result = {} # result sent to all loggers, init this to None incase it's not created
|
||||
|
||||
start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
|
||||
result = None # result sent to all loggers, init this to None incase it's not created
|
||||
for callback in litellm._async_failure_callback:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger): # custom logger class
|
||||
|
@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
|
|||
return model_cost
|
||||
|
||||
def get_litellm_params(
|
||||
return_async=False,
|
||||
api_key=None,
|
||||
force_timeout=600,
|
||||
azure=False,
|
||||
|
@ -2082,7 +2078,6 @@ def get_litellm_params(
|
|||
):
|
||||
litellm_params = {
|
||||
"acompletion": acompletion,
|
||||
"return_async": return_async,
|
||||
"api_key": api_key,
|
||||
"force_timeout": force_timeout,
|
||||
"logger_fn": logger_fn,
|
||||
|
@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
|
|||
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
|
||||
self.holding_chunk = ""
|
||||
self.complete_response = ""
|
||||
if self.logging_obj:
|
||||
# Log the type of the received item
|
||||
self.logging_obj.post_call(str(type(completion_stream)))
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def logging(self, text):
|
||||
if self.logging_obj:
|
||||
self.logging_obj.post_call(text)
|
||||
|
||||
def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
|
||||
hold = False
|
||||
if finish_reason:
|
||||
|
@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
|
|||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
# LOGGING
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
|
||||
print_verbose(f"model_response: {model_response}")
|
||||
return model_response
|
||||
else:
|
||||
return
|
||||
elif model_response.choices[0].finish_reason:
|
||||
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
|
||||
# LOGGING
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
|
||||
return model_response
|
||||
elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
|
||||
# enter this branch when no content has been passed in response
|
||||
|
@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
|
|||
if self.sent_first_chunk == False:
|
||||
model_response.choices[0].delta["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
# LOGGING
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
|
||||
return model_response
|
||||
else:
|
||||
return
|
||||
|
@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
|
|||
except Exception as e:
|
||||
traceback_exception = traceback.format_exc()
|
||||
e.message = str(e)
|
||||
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
|
||||
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
|
||||
raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
|
||||
|
||||
## needs to handle the empty string case (even starting chunk can be an empty string)
|
||||
|
@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
|
|||
chunk = next(self.completion_stream)
|
||||
if chunk is not None and chunk != b'':
|
||||
response = self.chunk_creator(chunk=chunk)
|
||||
if response is not None:
|
||||
return response
|
||||
if response is None:
|
||||
continue
|
||||
## LOGGING
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
|
||||
return response
|
||||
except StopIteration:
|
||||
raise # Re-raise StopIteration
|
||||
except Exception as e:
|
||||
# Handle other exceptions if needed
|
||||
traceback_exception = traceback.format_exc()
|
||||
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
|
||||
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
|
||||
raise e
|
||||
|
||||
|
||||
|
@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
|
|||
asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
|
||||
return processed_chunk
|
||||
except Exception as e:
|
||||
traceback_exception = traceback.format_exc()
|
||||
# Handle any exceptions that might occur during streaming
|
||||
asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
|
||||
raise StopAsyncIteration
|
||||
|
||||
class TextCompletionStreamWrapper:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue