(feat) async callbacks with litellm.completion()

This commit is contained in:
ishaan-jaff 2023-12-07 18:09:57 -08:00
parent 762f28e4d7
commit fd04b48764
3 changed files with 12 additions and 65 deletions

View file

@ -1264,6 +1264,7 @@ def client(original_function):
litellm._async_success_callback.append(callback)
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback)
print_verbose(f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}")
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
@ -1301,7 +1302,7 @@ def client(original_function):
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
litellm.success_callback.pop(index)
litellm.failure_callback.pop(index)
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
@ -1424,10 +1425,6 @@ def client(original_function):
return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None))
else:
return result
elif "acompletion" in kwargs and kwargs["acompletion"] == True:
return result
elif "aembedding" in kwargs and kwargs["aembedding"] == True:
return result
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model)
@ -1437,6 +1434,8 @@ def client(original_function):
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
print_verbose(f"Wrapper: Completed Call, calling async_success_handler")
asyncio.run(logging_obj.async_success_handler(result, start_time, end_time))
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
my_thread = threading.Thread(
@ -1444,6 +1443,10 @@ def client(original_function):
) # don't interrupt execution of main thread
my_thread.start()
# RETURN RESULT
if "acompletion" in kwargs and kwargs["acompletion"] == True:
return result
elif "aembedding" in kwargs and kwargs["aembedding"] == True:
return result
result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
return result
except Exception as e:
@ -1544,6 +1547,7 @@ def client(original_function):
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# RETURN RESULT