From fd04b48764856a20df4b0c1936db4ab61048efb4 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 7 Dec 2023 18:09:57 -0800
Subject: [PATCH] (feat) async callbacks with litellm.completion()

---
 litellm/proxy/custom_callbacks.py   | 57 -----------------------------
 litellm/tests/test_custom_logger.py |  6 +--
 litellm/utils.py                    | 14 ++++---
 3 files changed, 12 insertions(+), 65 deletions(-)

diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py
index 18aea8a97..08947a066 100644
--- a/litellm/proxy/custom_callbacks.py
+++ b/litellm/proxy/custom_callbacks.py
@@ -35,68 +35,11 @@ class MyCustomHandler(CustomLogger):
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async Success!")
-        # log: key, user, model, prompt, response, tokens, cost
-        # Access kwargs passed to litellm.completion()
-        model = kwargs.get("model", None)
-        messages = kwargs.get("messages", None)
-        user = kwargs.get("user", None)
-
-        # Access litellm_params passed to litellm.completion(), example access `metadata`
-        litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {})  # headers passed to LiteLLM proxy, can be found here
-
-        # Calculate cost using litellm.completion_cost()
-        cost = litellm.completion_cost(completion_response=response_obj)
-        response = response_obj
-        # tokens used in response
-        usage = response_obj["usage"]
-
-        print(
-            f"""
-                Model: {model},
-                Messages: {messages},
-                User: {user},
-                Usage: {usage},
-                Cost: {cost},
-                Response: {response}
-                Proxy Metadata: {metadata}
-            """
-        )
         return
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
             print(f"On Async Failure !")
-            print("\nkwargs", kwargs)
-            # Access kwargs passed to litellm.completion()
-            model = kwargs.get("model", None)
-            messages = kwargs.get("messages", None)
-            user = kwargs.get("user", None)
-
-            # Access litellm_params passed to litellm.completion(), example access `metadata`
-            litellm_params = kwargs.get("litellm_params", {})
-            metadata = litellm_params.get("metadata", {})  # headers passed to LiteLLM proxy, can be found here
-
-            # Acess Exceptions & Traceback
-            exception_event = kwargs.get("exception", None)
-            traceback_event = kwargs.get("traceback_exception", None)
-
-            # Calculate cost using litellm.completion_cost()
-            cost = litellm.completion_cost(completion_response=response_obj)
-            print("now checking response obj")
-
-            print(
-                f"""
-                    Model: {model},
-                    Messages: {messages},
-                    User: {user},
-                    Cost: {cost},
-                    Response: {response_obj}
-                    Proxy Metadata: {metadata}
-                    Exception: {exception_event}
-                    Traceback: {traceback_event}
-                """
-            )
         except Exception as e:
             print(f"Exception: {e}")
 
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index 0702cb52c..333293596 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -142,8 +142,8 @@ def test_async_custom_handler():
     assert len(str(customHandler2.async_completion_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n  File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n    response = await openai_aclient.chat.completions.create(**data)\n  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
     print("Passed setting async failure")
 
-    async def test_2():
-        response = await litellm.acompletion(
+    def test_2():
+        response = litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{
                 "role": "user",
@@ -152,7 +152,7 @@ def test_async_custom_handler():
         )
         print("\n response", response)
         assert customHandler2.async_success == False
-    asyncio.run(test_2())
+    test_2()
     assert customHandler2.async_success == True, "async success is not set to True even after success"
     assert customHandler2.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 4b64caa8b..879447194 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1264,6 +1264,7 @@ def client(original_function):
                     litellm._async_success_callback.append(callback)
                 if callback not in litellm._async_failure_callback:
                     litellm._async_failure_callback.append(callback)
+        print_verbose(f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}")
         if (
             len(litellm.input_callback) > 0
             or len(litellm.success_callback) > 0
@@ -1301,7 +1302,7 @@ def client(original_function):
 
             # Pop the async items from success_callback in reverse order to avoid index issues
             for index in reversed(removed_async_items):
-                litellm.success_callback.pop(index)
+                litellm.failure_callback.pop(index)
         if add_breadcrumb:
             add_breadcrumb(
                 category="litellm.llm_call",
@@ -1424,10 +1425,6 @@ def client(original_function):
                         return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None))
                 else:
                     return result
-            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
-                return result
-            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
-                return result
 
             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model)
@@ -1437,6 +1434,8 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
+            print_verbose(f"Wrapper: Completed Call, calling async_success_handler")
+            asyncio.run(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             my_thread = threading.Thread(
@@ -1444,6 +1443,10 @@ def client(original_function):
             )  # don't interrupt execution of main thread
             my_thread.start()
             # RETURN RESULT
+            if "acompletion" in kwargs and kwargs["acompletion"] == True:
+                return result
+            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
+                return result
             result._response_ms = (end_time - start_time).total_seconds() * 1000  # return response latency in ms like openai
             return result
         except Exception as e:
@@ -1544,6 +1547,7 @@ def client(original_function):
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                 litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS - handle streaming success logging in the _next_ object
+            print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
             asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # RETURN RESULT
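
Usage sketch (illustrative, not part of the diff): with this change, async success/failure callbacks defined on a CustomLogger fire even when the synchronous litellm.completion() entrypoint is used, since the sync wrapper now drives async_success_handler via asyncio.run(). A minimal sketch, assuming callbacks are registered through litellm.callbacks and that the mock_response kwarg is available to avoid a real provider call; the handler class name and prompt are made up for illustration:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyAsyncHandler(CustomLogger):
    def __init__(self):
        self.async_success = False

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # after this patch, the sync wrapper runs this coroutine with asyncio.run()
        self.async_success = True
        print("On Async Success!", kwargs.get("model"))

handler = MyAsyncHandler()
litellm.callbacks = [handler]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a song about sleeping"}],
    mock_response="zzz",  # assumption: mock_response skips the real API call; drop it to hit the provider
)
assert handler.async_success is True  # mirrors the assertion in test_custom_logger.py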
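
Dispatch pattern (for reference): the sync wrapper runs the coroutine to completion with asyncio.run(), while the already-async wrapper schedules it with asyncio.create_task(). A standalone sketch of the two paths, not litellm code:

import asyncio

async def async_success_handler(result):
    print("async handler saw:", result)

def sync_path(result):
    # sync entrypoint: block until the async handler finishes
    # (asyncio.run() raises if an event loop is already running in this thread)
    asyncio.run(async_success_handler(result))

async def async_path(result):
    # async entrypoint: schedule the handler without blocking the caller
    task = asyncio.create_task(async_success_handler(result))
    await task  # awaited here only so this toy example prints before exiting

sync_path("sync result")
asyncio.run(async_path("async result"))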