forked from phoenix/litellm-mirror
test(test_custom_callback_input.py): embedding callback tests for azure, openai, bedrock
This commit is contained in:
parent
8ee77d7b82
commit
ad39afc0ad
6 changed files with 185 additions and 49 deletions
|
@ -989,7 +989,7 @@ class Logging:
|
|||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
|
||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
|
||||
print_verbose(f"success callbacks: Running Custom Logger Class")
|
||||
if self.stream and complete_streaming_response is None:
|
||||
callback.log_stream_event(
|
||||
|
@ -1192,7 +1192,7 @@ class Logging:
|
|||
print_verbose=print_verbose,
|
||||
callback_func=callback
|
||||
)
|
||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class
|
||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
|
||||
callback.log_failure_event(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
|
@ -1641,7 +1641,7 @@ def client(original_function):
|
|||
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
|
||||
litellm.cache.add_cache(result, *args, **kwargs)
|
||||
# LOG SUCCESS - handle streaming success logging in the _next_ object
|
||||
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
|
||||
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
|
||||
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
|
||||
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
|
||||
# RETURN RESULT
|
||||
|
@ -1678,7 +1678,7 @@ def client(original_function):
|
|||
end_time = datetime.datetime.now()
|
||||
if logging_obj:
|
||||
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
|
||||
asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
|
||||
await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time)
|
||||
raise e
|
||||
|
||||
is_coroutine = inspect.iscoroutinefunction(original_function)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue