forked from phoenix/litellm-mirror

(feat) async callbacks with litellm.completion()

parent 762f28e4d7
commit fd04b48764

3 changed files with 12 additions and 65 deletions
@@ -35,68 +35,11 @@ class MyCustomHandler(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async Success!")
-        # log: key, user, model, prompt, response, tokens, cost
-        # Access kwargs passed to litellm.completion()
-        model = kwargs.get("model", None)
-        messages = kwargs.get("messages", None)
-        user = kwargs.get("user", None)
-
-        # Access litellm_params passed to litellm.completion(), example access `metadata`
-        litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {})  # headers passed to LiteLLM proxy, can be found here
-
-        # Calculate cost using litellm.completion_cost()
-        cost = litellm.completion_cost(completion_response=response_obj)
-        response = response_obj
-        # tokens used in response
-        usage = response_obj["usage"]
-
-        print(
-            f"""
-                Model: {model},
-                Messages: {messages},
-                User: {user},
-                Usage: {usage},
-                Cost: {cost},
-                Response: {response}
-                Proxy Metadata: {metadata}
-            """
-        )
         return

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
             print(f"On Async Failure !")
-            print("\nkwargs", kwargs)
-            # Access kwargs passed to litellm.completion()
-            model = kwargs.get("model", None)
-            messages = kwargs.get("messages", None)
-            user = kwargs.get("user", None)
-
-            # Access litellm_params passed to litellm.completion(), example access `metadata`
-            litellm_params = kwargs.get("litellm_params", {})
-            metadata = litellm_params.get("metadata", {})  # headers passed to LiteLLM proxy, can be found here
-
-            # Acess Exceptions & Traceback
-            exception_event = kwargs.get("exception", None)
-            traceback_event = kwargs.get("traceback_exception", None)
-
-            # Calculate cost using litellm.completion_cost()
-            cost = litellm.completion_cost(completion_response=response_obj)
-            print("now checking response obj")
-
-            print(
-                f"""
-                    Model: {model},
-                    Messages: {messages},
-                    User: {user},
-                    Cost: {cost},
-                    Response: {response_obj}
-                    Proxy Metadata: {metadata}
-                    Exception: {exception_event}
-                    Traceback: {traceback_event}
-                """
-            )
         except Exception as e:
             print(f"Exception: {e}")
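The trimmed handler above keeps only the print statements, presumably to keep the example minimal. What the commit is really demonstrating is that these async_* hooks now fire from the synchronous entry point too. A minimal sketch of that usage, assuming the standard litellm.callbacks registration (model and prompt are illustrative):

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyCustomHandler(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Success!")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Failure !")

customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]

# a plain synchronous call - the async_* hooks above are still invoked
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)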
@@ -142,8 +142,8 @@ def test_async_custom_handler():
     assert len(str(customHandler2.async_completion_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n  File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n    response = await openai_aclient.chat.completions.create(**data)\n  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
     print("Passed setting async failure")

-    async def test_2():
-        response = await litellm.acompletion(
+    def test_2():
+        response = litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{
                 "role": "user",
@@ -152,7 +152,7 @@ def test_async_custom_handler():
         )
         print("\n response", response)
         assert customHandler2.async_success == False
-    asyncio.run(test_2())
+    test_2()
     assert customHandler2.async_success == True, "async success is not set to True even after success"
     assert customHandler2.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
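The inner helper drops async/await entirely: the point of the change is that a plain litellm.completion() call must still flip the handler's async_success flag. A condensed, hypothetical version of that check (customHandler2 and its flag come from the surrounding test file):

# hypothetical condensation of the updated test body
assert customHandler2.async_success == False          # hook not fired yet
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],  # illustrative prompt
)
assert customHandler2.async_success == True           # fired by the sync call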
@@ -1264,6 +1264,7 @@ def client(original_function):
                     litellm._async_success_callback.append(callback)
                 if callback not in litellm._async_failure_callback:
                     litellm._async_failure_callback.append(callback)
+        print_verbose(f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}")
         if (
             len(litellm.input_callback) > 0
             or len(litellm.success_callback) > 0
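The one added line here is just verbose logging; the membership checks around it (the context lines) are what keep the same handler from being appended on every wrapped call, since client() initializes callbacks per call. The guard pattern in isolation, with hypothetical names:

def register_async_callbacks(callback, success_callbacks, failure_callbacks):
    # append-once guard: the wrapper runs for every call, so without the
    # membership tests each call would register the handler again
    if callback not in success_callbacks:
        success_callbacks.append(callback)
    if callback not in failure_callbacks:
        failure_callbacks.append(callback)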
@@ -1301,7 +1302,7 @@ def client(original_function):

         # Pop the async items from success_callback in reverse order to avoid index issues
         for index in reversed(removed_async_items):
-            litellm.success_callback.pop(index)
+            litellm.failure_callback.pop(index)
         if add_breadcrumb:
             add_breadcrumb(
                 category="litellm.llm_call",
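The one-word fix above suggests the old line popped from the wrong list (the comment over it reads like a copy-paste from the success-callback loop). The reversed() iteration is also load-bearing: popping ascending indices would shift the remaining ones. A self-contained illustration:

# popping in ascending order shifts later elements left and invalidates
# the indices still to come; reversed order keeps every stored index valid
items = ["a", "b", "c", "d"]
removed_async_items = [1, 3]          # indices collected earlier
for index in reversed(removed_async_items):
    items.pop(index)
assert items == ["a", "c"]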
@@ -1424,10 +1425,6 @@ def client(original_function):
                     return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None))
                 else:
                     return result
-            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
-                return result
-            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
-                return result

             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model)
@@ -1437,6 +1434,8 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)

             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
+            print_verbose(f"Wrapper: Completed Call, calling async_success_handler")
+            asyncio.run(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             my_thread = threading.Thread(
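This is the heart of the feature: after a synchronous completion() finishes, the wrapper drives the async success callbacks to completion with asyncio.run() before returning. A runnable sketch of that shape, where the handler body is a stand-in for logging_obj.async_success_handler:

import asyncio
import datetime

async def async_success_handler(result, start_time, end_time):
    # stand-in: the real handler awaits every registered
    # async_log_success_event hook in turn
    print("async success hook ran for", result)

start_time = end_time = datetime.datetime.now()
# in the sync wrapper: spin up a fresh event loop, run the async hooks
# to completion, then carry on returning the result
asyncio.run(async_success_handler("response", start_time, end_time))

One caveat: asyncio.run() raises a RuntimeError if the calling thread already has a running event loop, which is why the async code path in the -1544 hunk below schedules a task instead of blocking.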
@@ -1444,6 +1443,10 @@ def client(original_function):
             ) # don't interrupt execution of main thread
             my_thread.start()
             # RETURN RESULT
+            if "acompletion" in kwargs and kwargs["acompletion"] == True:
+                return result
+            elif "aembedding" in kwargs and kwargs["aembedding"] == True:
+                return result
             result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
             return result
         except Exception as e:
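Read together with the four-line deletion in the -1424 hunk, these additions relocate rather than introduce the early returns: the acompletion/aembedding branches now exit after the success logging has been dispatched, so async requests no longer bypass the new callback path, and only the sync path gets the latency attribute. Schematically (function and helper names are hypothetical):

import datetime

def wrapped_return(result, kwargs, start_time, end_time):
    # success logging has already been dispatched above this point
    if kwargs.get("acompletion") == True:
        return result                     # async path skips latency bookkeeping
    elif kwargs.get("aembedding") == True:
        return result
    # sync path: attach latency in ms, mirroring the diff above
    result._response_ms = (end_time - start_time).total_seconds() * 1000
    return result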
@@ -1544,6 +1547,7 @@ def client(original_function):
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                 litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS - handle streaming success logging in the _next_ object
+            print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
             asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # RETURN RESULT
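The async wrapper makes the mirror-image choice: it already runs inside an event loop, where asyncio.run() would raise, so it schedules the handler with asyncio.create_task() and returns without waiting. A minimal runnable contrast:

import asyncio

async def async_success_handler(result, start_time, end_time):
    print("hook ran for", result)

async def async_wrapper():
    # inside a running loop: schedule the hooks as a background task
    # instead of blocking on them (asyncio.run() would raise here)
    asyncio.create_task(async_success_handler("response", None, None))
    await asyncio.sleep(0)  # yield once so the toy task gets a chance to run
    return "response"

asyncio.run(async_wrapper())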