fix(utils.py): log success event for streaming

This commit is contained in:
Krrish Dholakia 2024-03-25 19:03:10 -07:00
parent a5776a3054
commit bd75498913
2 changed files with 33 additions and 12 deletions

View file

@ -651,6 +651,7 @@ async def test_async_chat_vertex_ai_stream():
try:
load_vertex_ai_credentials()
customHandler = CompletionCustomHandler()
litellm.set_verbose = True
litellm.callbacks = [customHandler]
# test streaming
response = await litellm.acompletion(
@ -667,6 +668,7 @@ async def test_async_chat_vertex_ai_stream():
async for chunk in response:
print(f"chunk: {chunk}")
continue
await asyncio.sleep(10)
print(f"customHandler.states: {customHandler.states}")
assert (
customHandler.states.count("async_success") == 1

View file

@ -1774,16 +1774,14 @@ class Logging:
end_time=end_time,
)
except Exception as e:
verbose_logger.debug( print_verbose(
f"Error occurred building stream chunk: {traceback.format_exc()}" f"Error occurred building stream chunk: {traceback.format_exc()}"
)
complete_streaming_response = None
else:
self.streaming_chunks.append(result)
if complete_streaming_response is not None:
verbose_logger.debug( print_verbose("Async success callbacks: Got a complete streaming response")
"Async success callbacks: Got a complete streaming response"
)
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
@ -1824,7 +1822,7 @@ class Logging:
callbacks.append(callback)
else:
callbacks = litellm._async_success_callback
verbose_logger.debug(f"Async success callbacks: {callbacks}") print_verbose(f"Async success callbacks: {callbacks}")
for callback in callbacks:
# check if callback can run for this request
litellm_params = self.model_call_details.get("litellm_params", {})
@ -1894,10 +1892,6 @@ class Logging:
end_time=end_time,
)
if callable(callback): # custom logger functions
# print_verbose(
# f"Making async function logging call for {callback}, result={result} - {self.model_call_details}",
# logger_only=True,
# )
if self.stream:
if (
"async_complete_streaming_response"
@ -9664,7 +9658,12 @@ class CustomStreamWrapper:
raise # Re-raise StopIteration
else:
self.sent_last_chunk = True
return self.finish_reason_handler() processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
).start() # log response
return processed_chunk
except Exception as e:
traceback_exception = traceback.format_exc()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
@ -9773,13 +9772,33 @@ class CustomStreamWrapper:
raise # Re-raise StopIteration
else:
self.sent_last_chunk = True
return self.finish_reason_handler() processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
return processed_chunk
except StopIteration:
if self.sent_last_chunk == True:
raise StopAsyncIteration
else:
self.sent_last_chunk = True
return self.finish_reason_handler() processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(processed_chunk,)
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
return processed_chunk
except Exception as e:
traceback_exception = traceback.format_exc()
# Handle any exceptions that might occur during streaming