test(test_custom_callback_input.py): embedding callback tests for azure, openai, bedrock

This commit is contained in:
Krrish Dholakia 2023-12-11 15:32:34 -08:00
parent 8ee77d7b82
commit ad39afc0ad
6 changed files with 185 additions and 49 deletions

View file

@ -989,7 +989,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
print_verbose(f"success callbacks: Running Custom Logger Class")
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
@ -1192,7 +1192,7 @@ class Logging:
print_verbose=print_verbose,
callback_func=callback
)
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
callback.log_failure_event(
start_time=start_time,
end_time=end_time,
@ -1641,7 +1641,7 @@ def client(original_function):
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# RETURN RESULT
@ -1678,7 +1678,7 @@ def client(original_function):
end_time = datetime.datetime.now()
if logging_obj:
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time)
raise e
is_coroutine = inspect.iscoroutinefunction(original_function)