From 482180cfba29db9d36752860e04bc5de844696ad Mon Sep 17 00:00:00 2001
From: B-Step62 
Date: Mon, 3 Feb 2025 10:06:12 +0900
Subject: [PATCH 1/6] Add a configuration option to make callback logging synchronous

Signed-off-by: B-Step62 
---
 litellm/__init__.py                            |  2 +
 litellm/batches/batch_utils.py                 | 11 ++-
 litellm/caching/caching_handler.py             | 22 +++--
 .../litellm_core_utils/streaming_handler.py    | 85 +++++++++++--------
 .../pass_through_endpoints/success_handler.py  |  1 +
 litellm/utils.py                               | 18 ++--
 6 files changed, 87 insertions(+), 52 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 506ecb258e..301c882c47 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -112,6 +112,8 @@ _known_custom_logger_compatible_callbacks: List = list(
 callbacks: List[
     Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
 ] = []
+# If true, callbacks will be executed synchronously (blocking).
+sync_logging: bool = False
 langfuse_default_tags: Optional[List[str]] = None
 langsmith_batch_size: Optional[int] = None
 prometheus_initialize_budget_metrics: Optional[bool] = False
diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py
index f24eda0432..acdfc5252f 100644
--- a/litellm/batches/batch_utils.py
+++ b/litellm/batches/batch_utils.py
@@ -129,10 +129,13 @@ async def _log_completed_batch(
             cache_hit=None,
         )
     )
-    threading.Thread(
-        target=logging_obj.success_handler,
-        args=(None, start_time, end_time),
-    ).start()
+    if litellm.sync_logging:
+        logging_obj.success_handler(None, start_time, end_time)
+    else:
+        threading.Thread(
+            target=logging_obj.success_handler,
+            args=(None, start_time, end_time),
+        ).start()
 
 
 async def _batch_cost_calculator(
diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py
index 40c1001732..0cd88911b1 100644
--- a/litellm/caching/caching_handler.py
+++ b/litellm/caching/caching_handler.py
@@ -280,10 +280,13 @@ class LLMCachingHandler:
                     is_async=False,
                 )
 
-                threading.Thread(
-                    target=logging_obj.success_handler,
-                    args=(cached_result, start_time, end_time, cache_hit),
-                ).start()
+                if litellm.sync_logging:
+                    logging_obj.success_handler(cached_result, start_time, end_time, True)
+                else:
+                    threading.Thread(
+                        target=logging_obj.success_handler,
+                        args=(cached_result, start_time, end_time, cache_hit),
+                    ).start()
                 cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                     **kwargs
                 )
@@ -449,10 +452,13 @@ class LLMCachingHandler:
                 cached_result, start_time, end_time, cache_hit
             )
         )
-        threading.Thread(
-            target=logging_obj.success_handler,
-            args=(cached_result, start_time, end_time, cache_hit),
-        ).start()
+        if litellm.sync_logging:
+            logging_obj.success_handler(cached_result, start_time, end_time, cache_hit)
+        else:
+            threading.Thread(
+                target=logging_obj.success_handler,
+                args=(cached_result, start_time, end_time, cache_hit),
+            ).start()
 
     async def _retrieve_from_cache(
         self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index 08356fea73..b9129d41a5 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -1356,22 +1356,25 @@ class CustomStreamWrapper:
             Disables streaming logging.
""" return - ## ASYNC LOGGING - # Create an event loop for the new thread - if self.logging_loop is not None: - future = asyncio.run_coroutine_threadsafe( - self.logging_obj.async_success_handler( - processed_chunk, None, None, cache_hit - ), - loop=self.logging_loop, - ) - future.result() - else: - asyncio.run( - self.logging_obj.async_success_handler( - processed_chunk, None, None, cache_hit + + if not litellm.sync_logging: + ## ASYNC LOGGING + # Create an event loop for the new thread + if self.logging_loop is not None: + future = asyncio.run_coroutine_threadsafe( + self.logging_obj.async_success_handler( + processed_chunk, None, None, cache_hit + ), + loop=self.logging_loop, ) - ) + future.result() + else: + asyncio.run( + self.logging_obj.async_success_handler( + processed_chunk, None, None, cache_hit + ) + ) + ## SYNC LOGGING self.logging_obj.success_handler(processed_chunk, None, None, cache_hit) @@ -1427,10 +1430,13 @@ class CustomStreamWrapper: if response is None: continue ## LOGGING - threading.Thread( - target=self.run_success_logging_and_cache_storage, - args=(response, cache_hit), - ).start() # log response + if litellm.sync_logging: + self.run_success_logging_and_cache_storage(response, cache_hit) + else: + threading.Thread( + target=self.run_success_logging_and_cache_storage, + args=(response, cache_hit), + ).start() # log response choice = response.choices[0] if isinstance(choice, StreamingChoices): self.response_uptil_now += choice.delta.get("content", "") or "" @@ -1476,10 +1482,13 @@ class CustomStreamWrapper: ) ## LOGGING - threading.Thread( - target=self.logging_obj.success_handler, - args=(response, None, None, cache_hit), - ).start() # log response + if litellm.sync_logging: + self.logging_obj.success_handler(response, None, None, cache_hit) + else: + threading.Thread( + target=self.logging_obj.success_handler, + args=(response, None, None, cache_hit), + ).start() # log response if self.sent_stream_usage is False and self.send_stream_usage is True: self.sent_stream_usage = True @@ -1492,10 +1501,13 @@ class CustomStreamWrapper: usage = calculate_total_usage(chunks=self.chunks) processed_chunk._hidden_params["usage"] = usage ## LOGGING - threading.Thread( - target=self.run_success_logging_and_cache_storage, - args=(processed_chunk, cache_hit), - ).start() # log response + if litellm.sync_logging: + self.run_success_logging_and_cache_storage(processed_chunk, cache_hit) + else: + threading.Thread( + target=self.run_success_logging_and_cache_storage, + args=(processed_chunk, cache_hit), + ).start() # log response return processed_chunk except Exception as e: traceback_exception = traceback.format_exc() @@ -1654,13 +1666,18 @@ class CustomStreamWrapper: ) ) - executor.submit( - self.logging_obj.success_handler, - complete_streaming_response, - cache_hit=cache_hit, - start_time=None, - end_time=None, - ) + if litellm.sync_logging: + self.logging_obj.success_handler( + complete_streaming_response, None, None, cache_hit + ) + else: + executor.submit( + self.logging_obj.success_handler, + complete_streaming_response, + cache_hit=cache_hit, + start_time=None, + end_time=None, + ) raise StopAsyncIteration # Re-raise StopIteration else: diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 6f112aed1f..527121ee6b 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -83,6 +83,7 @@ class PassThroughEndpointLogging: 
standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text ) + thread_pool_executor.submit( logging_obj.success_handler, standard_logging_response_object, # Positional argument 1 diff --git a/litellm/utils.py b/litellm/utils.py index 7197862e3a..eff973dadf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1082,12 +1082,18 @@ def client(original_function): # noqa: PLR0915 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated verbose_logger.info("Wrapper: Completed Call, calling success_handler") - executor.submit( - logging_obj.success_handler, - result, - start_time, - end_time, - ) + if litellm.sync_logging: + print("sync_logging") + logging_obj.success_handler(result, start_time, end_time) + else: + print("async_logging") + executor.submit( + logging_obj.success_handler, + result, + start_time, + end_time, + ) + # RETURN RESULT update_response_metadata( result=result, From f35609677297c334e82a99762a4a4e492649810c Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Mon, 3 Feb 2025 13:46:21 +0900 Subject: [PATCH 2/6] nit Signed-off-by: B-Step62 --- litellm/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index eff973dadf..1ac0639893 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1083,10 +1083,8 @@ def client(original_function): # noqa: PLR0915 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated verbose_logger.info("Wrapper: Completed Call, calling success_handler") if litellm.sync_logging: - print("sync_logging") logging_obj.success_handler(result, start_time, end_time) else: - print("async_logging") executor.submit( logging_obj.success_handler, result, From 77de86e0e7cc910083ee8b9fb42498aecb603402 Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Mon, 3 Feb 2025 21:55:00 +0900 Subject: [PATCH 3/6] refactor Signed-off-by: B-Step62 --- litellm/batches/batch_utils.py | 8 +- litellm/caching/caching_handler.py | 16 +-- litellm/litellm_core_utils/litellm_logging.py | 23 ++++- .../litellm_core_utils/streaming_handler.py | 98 +++++++++---------- .../streaming_handler.py | 12 +-- .../pass_through_endpoints/success_handler.py | 2 + litellm/utils.py | 10 +- .../test_unit_tests_init_callbacks.py | 8 +- 8 files changed, 78 insertions(+), 99 deletions(-) diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py index acdfc5252f..e469f23bda 100644 --- a/litellm/batches/batch_utils.py +++ b/litellm/batches/batch_utils.py @@ -129,13 +129,7 @@ async def _log_completed_batch( cache_hit=None, ) ) - if litellm.sync_logging: - logging_obj.success_handler(None, start_time, end_time) - else: - threading.Thread( - target=logging_obj.success_handler, - args=(None, start_time, end_time), - ).start() + logging_obj.success_handler(None, start_time, end_time) async def _batch_cost_calculator( diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 0cd88911b1..617457833d 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -280,13 +280,7 @@ class LLMCachingHandler: is_async=False, ) - if litellm.sync_logging: - logging_obj.success_handler(cached_result, start_time, end_time, True) - else: - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() + logging_obj.success_handler(cached_result, start_time, end_time, True) cache_key = 
litellm.cache._get_preset_cache_key_from_kwargs( **kwargs ) @@ -452,13 +446,7 @@ class LLMCachingHandler: cached_result, start_time, end_time, cache_hit ) ) - if litellm.sync_logging: - logging_obj.success_handler(cached_result, start_time, end_time, cache_hit) - else: - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() + logging_obj.success_handler(cached_result, start_time, end_time, cache_hit) async def _retrieve_from_cache( self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...] diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 45b63177b9..7126b7edb0 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -8,12 +8,13 @@ import os import re import subprocess import sys +import threading import time import traceback import uuid from datetime import datetime as dt_object from functools import lru_cache -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast from pydantic import BaseModel @@ -1008,7 +1009,23 @@ class Logging(LiteLLMLoggingBaseClass): except Exception as e: raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") - def success_handler( # noqa: PLR0915 + def success_handler(self, *args, synchronous: Optional[bool]=None): + """ + Execute the success handler function in a sync or async manner. + If synchronous argument is not provided, global `litellm.sync_logging` config is used. + """ + if synchronous is None: + synchronous = litellm.sync_logging + + if synchronous: + self._success_handler(*args) + else: + threading.Thread( + target=self._success_handler, + args=args, + ).start() + + def _success_handler( # noqa: PLR0915 self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs ): verbose_logger.debug( @@ -2151,6 +2168,8 @@ class Logging(LiteLLMLoggingBaseClass): result, start_time, end_time, + # NB: Since we already run this in a TPE, the handler itself can run sync + synchronous=True, ) def _should_run_sync_callbacks_for_async_calls(self) -> bool: diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index b9129d41a5..0b9c4a2a31 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1348,41 +1348,47 @@ class CustomStreamWrapper: """ Runs success logging in a thread and adds the response to the cache """ - if litellm.disable_streaming_logging is True: - """ - [NOT RECOMMENDED] - Set this via `litellm.disable_streaming_logging = True`. + def _run(): + if litellm.disable_streaming_logging is True: + """ + [NOT RECOMMENDED] + Set this via `litellm.disable_streaming_logging = True`. - Disables streaming logging. - """ - return + Disables streaming logging. 
+ """ + return - if not litellm.sync_logging: - ## ASYNC LOGGING - # Create an event loop for the new thread - if self.logging_loop is not None: - future = asyncio.run_coroutine_threadsafe( - self.logging_obj.async_success_handler( - processed_chunk, None, None, cache_hit - ), - loop=self.logging_loop, - ) - future.result() - else: - asyncio.run( - self.logging_obj.async_success_handler( - processed_chunk, None, None, cache_hit + if not litellm.sync_logging: + ## ASYNC LOGGING + # Create an event loop for the new thread + if self.logging_loop is not None: + future = asyncio.run_coroutine_threadsafe( + self.logging_obj.async_success_handler( + processed_chunk, None, None, cache_hit + ), + loop=self.logging_loop, ) + future.result() + else: + asyncio.run( + self.logging_obj.async_success_handler( + processed_chunk, None, None, cache_hit + ) + ) + + ## SYNC LOGGING + self.logging_obj.success_handler(processed_chunk, None, None, cache_hit) + + ## Sync store in cache + if self.logging_obj._llm_caching_handler is not None: + self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache( + processed_chunk ) - ## SYNC LOGGING - self.logging_obj.success_handler(processed_chunk, None, None, cache_hit) - - ## Sync store in cache - if self.logging_obj._llm_caching_handler is not None: - self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache( - processed_chunk - ) + if litellm.sync_logging: + _run() + else: + threading.Thread(target=_run).start() def finish_reason_handler(self): model_response = self.model_response_creator() @@ -1430,13 +1436,7 @@ class CustomStreamWrapper: if response is None: continue ## LOGGING - if litellm.sync_logging: - self.run_success_logging_and_cache_storage(response, cache_hit) - else: - threading.Thread( - target=self.run_success_logging_and_cache_storage, - args=(response, cache_hit), - ).start() # log response + self.run_success_logging_and_cache_storage(response, cache_hit) choice = response.choices[0] if isinstance(choice, StreamingChoices): self.response_uptil_now += choice.delta.get("content", "") or "" @@ -1482,13 +1482,7 @@ class CustomStreamWrapper: ) ## LOGGING - if litellm.sync_logging: - self.logging_obj.success_handler(response, None, None, cache_hit) - else: - threading.Thread( - target=self.logging_obj.success_handler, - args=(response, None, None, cache_hit), - ).start() # log response + self.logging_obj.success_handler(response, None, None, cache_hit) if self.sent_stream_usage is False and self.send_stream_usage is True: self.sent_stream_usage = True @@ -1501,13 +1495,7 @@ class CustomStreamWrapper: usage = calculate_total_usage(chunks=self.chunks) processed_chunk._hidden_params["usage"] = usage ## LOGGING - if litellm.sync_logging: - self.run_success_logging_and_cache_storage(processed_chunk, cache_hit) - else: - threading.Thread( - target=self.run_success_logging_and_cache_storage, - args=(processed_chunk, cache_hit), - ).start() # log response + self.run_success_logging_and_cache_storage(processed_chunk, cache_hit) return processed_chunk except Exception as e: traceback_exception = traceback.format_exc() @@ -1674,9 +1662,11 @@ class CustomStreamWrapper: executor.submit( self.logging_obj.success_handler, complete_streaming_response, - cache_hit=cache_hit, - start_time=None, - end_time=None, + None, + None, + cache_hit, + # NB: We already run this in a TPE so the handler itself should run sync + synchronous=True, ) raise StopAsyncIteration # Re-raise StopIteration diff --git 
a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index b022bf1d25..1e60d27483 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -123,15 +123,9 @@ class PassThroughStreamingHandler: standard_logging_response_object = StandardPassThroughResponseObject( response=f"cannot parse chunks to standard response object. Chunks={all_chunks}" ) - threading.Thread( - target=litellm_logging_obj.success_handler, - args=( - standard_logging_response_object, - start_time, - end_time, - False, - ), - ).start() + litellm_logging_obj.success_handler( + standard_logging_response_object, start_time, end_time, False + ) await litellm_logging_obj.async_success_handler( result=standard_logging_response_object, start_time=start_time, diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 527121ee6b..addfcbe371 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -91,6 +91,8 @@ class PassThroughEndpointLogging: end_time, # Positional argument 3 cache_hit, # Positional argument 4 **kwargs, # Unpacked keyword arguments + # NB: Since we already run this in a TPE, the handler itself can run sync + synchronous=True, ) await logging_obj.async_success_handler( diff --git a/litellm/utils.py b/litellm/utils.py index 1ac0639893..9421726182 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1082,15 +1082,7 @@ def client(original_function): # noqa: PLR0915 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated verbose_logger.info("Wrapper: Completed Call, calling success_handler") - if litellm.sync_logging: - logging_obj.success_handler(result, start_time, end_time) - else: - executor.submit( - logging_obj.success_handler, - result, - start_time, - end_time, - ) + logging_obj.success_handler(result, start_time, end_time) # RETURN RESULT update_response_metadata( diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 7d77e26aaf..d6d03c4d11 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -262,7 +262,7 @@ def test_dynamic_logging_global_callback(): try: litellm_logging.success_handler( - result=ModelResponse( + ModelResponse( id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", created=1732306261, model="claude-3-opus-20240229", @@ -288,9 +288,9 @@ def test_dynamic_logging_global_callback(): prompt_tokens_details=None, ), ), - start_time=datetime.now(), - end_time=datetime.now(), - cache_hit=False, + datetime.now(), + datetime.now(), + False, ) except Exception as e: print(f"Error: {e}") From 5d8b359384a92d251a0f052dccd50706fcb1bef7 Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Fri, 28 Feb 2025 18:39:15 +0900 Subject: [PATCH 4/6] comments and small fix Signed-off-by: B-Step62 --- litellm/integrations/mlflow.py | 6 +- litellm/litellm_core_utils/litellm_logging.py | 15 ++- .../litellm_core_utils/streaming_handler.py | 25 ++--- litellm/utils.py | 31 +++--- .../test_unit_tests_init_callbacks.py | 99 ++++++++++++++++++- 5 files changed, 141 insertions(+), 35 deletions(-) diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py index 
193d1c4ea2..4bcc7271ce 100644 --- a/litellm/integrations/mlflow.py +++ b/litellm/integrations/mlflow.py @@ -122,7 +122,11 @@ class MlflowLogger(CustomLogger): # If this is the final chunk, end the span. The final chunk # has complete_streaming_response that gathers the full response. - if final_response := kwargs.get("complete_streaming_response"): + final_response = ( + kwargs.get("complete_streaming_response") + or kwargs.get("async_complete_streaming_response") + ) + if final_response: end_time_ns = int(end_time.timestamp() * 1e9) self._extract_and_set_chat_attributes(span, kwargs, final_response) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 6c9d9409e3..d528b92c14 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1012,7 +1012,15 @@ class Logging(LiteLLMLoggingBaseClass): except Exception as e: raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") - def success_handler(self, *args, synchronous: Optional[bool]=None): + def success_handler( + self, + result=None, + start_time=None, + end_time=None, + cache_hit=None, + synchronous=None, + **kwargs + ): """ Execute the success handler function in a sync or async manner. If synchronous argument is not provided, global `litellm.sync_logging` config is used. @@ -1021,11 +1029,12 @@ class Logging(LiteLLMLoggingBaseClass): synchronous = litellm.sync_logging if synchronous: - self._success_handler(*args) + self._success_handler(result, start_time, end_time, cache_hit, **kwargs) else: threading.Thread( target=self._success_handler, - args=args, + args=(result, start_time, end_time, cache_hit), + kwargs=kwargs, ).start() def _success_handler( # noqa: PLR0915 diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 0c6bbc4e37..1a27b703e4 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1696,26 +1696,21 @@ class CustomStreamWrapper: self.sent_stream_usage = True return response - asyncio.create_task( - self.logging_obj.async_success_handler( - complete_streaming_response, - cache_hit=cache_hit, - start_time=None, - end_time=None, - ) - ) + logging_params = dict( + result=complete_streaming_response, + cache_hit=cache_hit, + start_time=None, + end_time=None, + ) if litellm.sync_logging: - self.logging_obj.success_handler( - complete_streaming_response, None, None, cache_hit - ) + await self.logging_obj.async_success_handler(**logging_params) + self.logging_obj.success_handler(**logging_params, synchronous=True) else: + asyncio.create_task(self.logging_obj.async_success_handler(**logging_params)) executor.submit( self.logging_obj.success_handler, - complete_streaming_response, - None, - None, - cache_hit, + **logging_params, # NB: We already run this in a TPE so the handler itself should run sync synchronous=True, ) diff --git a/litellm/utils.py b/litellm/utils.py index 5aea3b8d0b..2c878807f1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -749,9 +749,7 @@ async def _client_async_logging_helper( f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" ) # check if user does not want this to be logged - asyncio.create_task( - logging_obj.async_success_handler(result, start_time, end_time) - ) + await logging_obj.async_success_handler(result, start_time, end_time) logging_obj.handle_sync_success_callbacks_for_async_calls( 
result=result, start_time=start_time, @@ -1121,7 +1119,13 @@ def client(original_function): # noqa: PLR0915 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated verbose_logger.info("Wrapper: Completed Call, calling success_handler") - logging_obj.success_handler(result, start_time, end_time) + if litellm.sync_logging: + logging_obj.success_handler(result, start_time, end_time) + else: + executor.submit( + # NB: We already run this in a TPE so the handler itself should run sync + logging_obj.success_handler, result, start_time, end_time, synchronous=True, + ) # RETURN RESULT update_response_metadata( @@ -1288,15 +1292,18 @@ def client(original_function): # noqa: PLR0915 ) # LOG SUCCESS - handle streaming success logging in the _next_ object - asyncio.create_task( - _client_async_logging_helper( - logging_obj=logging_obj, - result=result, - start_time=start_time, - end_time=end_time, - is_completion_with_fallbacks=is_completion_with_fallbacks, - ) + async_logging_params = dict( + logging_obj=logging_obj, + result=result, + start_time=start_time, + end_time=end_time, + is_completion_with_fallbacks=is_completion_with_fallbacks, ) + if litellm.sync_logging: + await _client_async_logging_helper(**async_logging_params) + else: + asyncio.create_task(_client_async_logging_helper(**async_logging_params)) + logging_obj.handle_sync_success_callbacks_for_async_calls( result=result, start_time=start_time, diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 3a266048c4..b340fbde4a 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -2,8 +2,11 @@ import json import os import sys from datetime import datetime +import threading from unittest.mock import AsyncMock +from litellm.litellm_core_utils.dd_tracing import contextmanager + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system-path @@ -264,7 +267,7 @@ def test_dynamic_logging_global_callback(): try: litellm_logging.success_handler( - ModelResponse( + result=ModelResponse( id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", created=1732306261, model="claude-3-opus-20240229", @@ -290,9 +293,10 @@ def test_dynamic_logging_global_callback(): prompt_tokens_details=None, ), ), - datetime.now(), - datetime.now(), - False, + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + synchronous=True, ) except Exception as e: print(f"Error: {e}") @@ -319,3 +323,90 @@ def test_get_combined_callback_list(): assert "lago" in _logging.get_combined_callback_list( dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] ) + + + +@pytest.mark.parametrize("sync_logging", [True, False]) +def test_success_handler_sync_async(sync_logging): + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.integrations.custom_logger import CustomLogger + from litellm.types.utils import ModelResponse, Choices, Message + + cl = CustomLogger() + + litellm_logging = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + ) + + litellm.sync_logging = sync_logging + + + @contextmanager + def patch_thread(): + """ + A context manager to collect threads started for logging 
handlers.
+        This is done by monkey-patching the start() method of threading.Thread.
+        Note that failure handlers are executed synchronously, so we don't need to patch them.
+        """
+        original = threading.Thread.start
+        logging_threads = []
+
+        def _patched_start(self, *args, **kwargs):
+            logging_threads.append(self)
+            return original(self, *args, **kwargs)
+
+        threading.Thread.start = _patched_start
+        try:
+            yield logging_threads
+        finally:
+            threading.Thread.start = original
+
+    result = ModelResponse(
+        id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
+        created=1732306261,
+        model="claude-3-opus-20240229",
+        object="chat.completion",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="hello",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+    )
+
+
+    with patch.object(cl, "log_success_event") as mock_log_success_event:
+        litellm.success_callback = [cl]
+
+        with patch_thread() as logging_threads:
+            litellm_logging.success_handler(
+                result=result,
+                start_time=datetime.now(),
+                end_time=datetime.now(),
+                cache_hit=False,
+            )
+
+        if sync_logging:
+            mock_log_success_event.assert_called_once()
+            assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]
+            assert logging_threads == []
+        else:
+            mock_log_success_event.assert_not_called()
+            assert len(logging_threads) == 1
+
+            # Wait for the thread to finish
+            logging_threads[0].join()
+            mock_log_success_event.assert_called_once()
+            assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]

From 53443a8d5a925b1ac04fc1855ae326ef9592c6de Mon Sep 17 00:00:00 2001
From: B-Step62 
Date: Thu, 20 Mar 2025 02:06:34 +0900
Subject: [PATCH 5/6] refactor

Signed-off-by: B-Step62 
---
 litellm/litellm_core_utils/litellm_logging.py | 14 ++--
 .../litellm_core_utils/streaming_handler.py   | 22 ++-------
 .../proxy/_experimental/out/onboarding.html   |  1 -
 .../test_unit_tests_init_callbacks.py         | 49 ++++++-------
 4 files changed, 28 insertions(+), 58 deletions(-)
 delete mode 100644 litellm/proxy/_experimental/out/onboarding.html

diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 2deb3d4f07..e7c34d0ede 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -1103,11 +1103,15 @@ class Logging(LiteLLMLoggingBaseClass):
         if synchronous:
             self._success_handler(result, start_time, end_time, cache_hit, **kwargs)
         else:
-            threading.Thread(
-                target=self._success_handler,
-                args=(result, start_time, end_time, cache_hit),
-                kwargs=kwargs,
-            ).start()
+            executor.submit(
+                self._success_handler,
+                result,
+                start_time,
+                end_time,
+                cache_hit,
+                **kwargs,
+            )
+
 
     def _success_handler(  # noqa: PLR0915
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index 585bfffd13..c994018b9d 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -1534,7 +1534,7 @@ class CustomStreamWrapper:
         if litellm.sync_logging:
             _run()
         else:
-            threading.Thread(target=_run).start()
+            executor.submit(_run)
 
     def finish_reason_handler(self):
         model_response = self.model_response_creator()
@@ -1634,26 +1634,10 @@ class CustomStreamWrapper:
                 cache_hit=cache_hit,
             )
             logging_result = complete_streaming_response.model_copy(deep=True)
-            executor.submit(
-                self.logging_obj.success_handler,
-
complete_streaming_response.model_copy(deep=True), - None, - None, - cache_hit, - ) else: logging_result = response - - if litellm.sync_logging: - self.logging_obj.success_handler(logging_result, None, None, cache_hit) - else: - executor.submit( - self.logging_obj.success_handler, - logging_result, - None, - None, - cache_hit, - ) + + self.logging_obj.success_handler(logging_result, None, None, cache_hit) if self.sent_stream_usage is False and self.send_stream_usage is True: self.sent_stream_usage = True diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 82f43619df..0000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index b340fbde4a..00354150e1 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -6,6 +6,7 @@ import threading from unittest.mock import AsyncMock from litellm.litellm_core_utils.dd_tracing import contextmanager +from litellm.utils import executor sys.path.insert( 0, os.path.abspath("../..") @@ -346,27 +347,6 @@ def test_success_handler_sync_async(sync_logging): litellm.sync_logging = sync_logging - - @contextmanager - def patch_thread(): - """ - A context manager to collect threads started for logging handlers. - This is done by monkey-patching the start() method of threading.Thread. - Note that failure handlers are executed synchronously, so we don't need to patch them. - """ - original = threading.Thread.start - logging_threads = [] - - def _patched_start(self, *args, **kwargs): - logging_threads.append(self) - return original(self, *args, **kwargs) - - threading.Thread.start = _patched_start - try: - yield logging_threads - finally: - threading.Thread.start = original - result = ModelResponse( id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", created=1732306261, @@ -387,26 +367,29 @@ def test_success_handler_sync_async(sync_logging): ) - with patch.object(cl, "log_success_event") as mock_log_success_event: + with ( + patch.object(cl, "log_success_event") as mock_log_success_event, + patch.object( + executor, "submit", + side_effect=lambda *args: args[0](*args[1:]) + ) as mock_executor, + ): litellm.success_callback = [cl] - with patch_thread() as logging_threads: - litellm_logging.success_handler( - result=result, - start_time=datetime.now(), - end_time=datetime.now(), - cache_hit=False, - ) + litellm_logging.success_handler( + result=result, + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + ) if sync_logging: + mock_executor.assert_not_called() mock_log_success_event.assert_called_once() assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"] - assert logging_threads == [] else: - mock_log_success_event.assert_not_called() - assert len(logging_threads) == 1 # Wait for the thread to finish - logging_threads[0].join() + mock_executor.assert_called_once() mock_log_success_event.assert_called_once() assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"] From 57f1d436e1b3aa99500767a47e6412f15e35b11c Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Thu, 20 Mar 2025 02:06:34 +0900 Subject: [PATCH 6/6] clean up Signed-off-by: B-Step62 --- litellm/litellm_core_utils/litellm_logging.py | 
9 +-------- litellm/litellm_core_utils/streaming_handler.py | 9 ++------- litellm/responses/streaming_iterator.py | 3 +-- litellm/utils.py | 8 +------- 4 files changed, 5 insertions(+), 24 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index e7c34d0ede..53fb33dcbb 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2264,14 +2264,7 @@ class Logging(LiteLLMLoggingBaseClass): if self._should_run_sync_callbacks_for_async_calls() is False: return - executor.submit( - self.success_handler, - result, - start_time, - end_time, - # NB: Since we already run this in a TPE, the handler itself can run sync - synchronous=True, - ) + self.success_handler(result, start_time, end_time) def _should_run_sync_callbacks_for_async_calls(self) -> bool: """ diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index c994018b9d..ae1a9fc66c 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1810,15 +1810,10 @@ class CustomStreamWrapper: ) if litellm.sync_logging: await self.logging_obj.async_success_handler(**logging_params) - self.logging_obj.success_handler(**logging_params, synchronous=True) else: asyncio.create_task(self.logging_obj.async_success_handler(**logging_params)) - executor.submit( - self.logging_obj.success_handler, - **logging_params, - # NB: We already run this in a TPE so the handler itself should run sync - synchronous=True, - ) + + self.logging_obj.success_handler(**logging_params) raise StopAsyncIteration # Re-raise StopIteration else: diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py index c016e71e7e..804f6fce6a 100644 --- a/litellm/responses/streaming_iterator.py +++ b/litellm/responses/streaming_iterator.py @@ -140,8 +140,7 @@ class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator): ) ) - executor.submit( - self.logging_obj.success_handler, + self.logging_obj.success_handler( result=self.completed_response, cache_hit=None, start_time=self.start_time, diff --git a/litellm/utils.py b/litellm/utils.py index b923c0e036..2394e031b2 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1161,13 +1161,7 @@ def client(original_function): # noqa: PLR0915 # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated verbose_logger.info("Wrapper: Completed Call, calling success_handler") - if litellm.sync_logging: - logging_obj.success_handler(result, start_time, end_time) - else: - executor.submit( - # NB: We already run this in a TPE so the handler itself should run sync - logging_obj.success_handler, result, start_time, end_time, synchronous=True, - ) + logging_obj.success_handler(result, start_time, end_time) # RETURN RESULT update_response_metadata(
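
Below is a minimal usage sketch of the option this series adds; it is illustrative and not part of the patches. It relies only on APIs visible in the series (`litellm.sync_logging`, `CustomLogger`, `litellm.success_callback`) plus litellm's `mock_response` kwarg, which fakes a provider response; the `AuditLogger` name and the model string are made up for the example:

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class AuditLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            # With sync_logging enabled, this runs before completion() returns,
            # so a slow callback adds latency directly to the request path.
            print("logged:", response_obj.model)

    litellm.sync_logging = True  # execute callbacks synchronously (blocking)
    litellm.success_callback = [AuditLogger()]

    response = litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="hello",  # avoid a real provider call
    )

The trade-off is the one the series encodes: threaded/executor logging keeps request latency low but offers no guarantee the callback has run by the time the call returns (a problem for short-lived processes and for tests), while synchronous logging guarantees completion at the cost of blocking the caller.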