diff --git a/litellm/__init__.py b/litellm/__init__.py
index 59c8c78eb9..d96781cba3 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -123,6 +123,8 @@ _known_custom_logger_compatible_callbacks: List = list(
 callbacks: List[
     Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
 ] = []
+# If True, callbacks are executed synchronously, blocking the calling thread.
+sync_logging: bool = False
 langfuse_default_tags: Optional[List[str]] = None
 langsmith_batch_size: Optional[int] = None
 prometheus_initialize_budget_metrics: Optional[bool] = False
diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py
index 14278de9cd..2bfb48cc68 100644
--- a/litellm/caching/caching_handler.py
+++ b/litellm/caching/caching_handler.py
@@ -277,10 +277,7 @@ class LLMCachingHandler:
                     is_async=False,
                 )
 
-                threading.Thread(
-                    target=logging_obj.success_handler,
-                    args=(cached_result, start_time, end_time, cache_hit),
-                ).start()
+                logging_obj.success_handler(cached_result, start_time, end_time, True)
                 cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                     **kwargs
                 )
@@ -446,10 +443,7 @@ class LLMCachingHandler:
                 cached_result, start_time, end_time, cache_hit
             )
         )
-        threading.Thread(
-            target=logging_obj.success_handler,
-            args=(cached_result, start_time, end_time, cache_hit),
-        ).start()
+        logging_obj.success_handler(cached_result, start_time, end_time, cache_hit)
 
     async def _retrieve_from_cache(
         self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
diff --git a/litellm/integrations/mlflow.py b/litellm/integrations/mlflow.py
index e7a458accf..2812abeb6a 100644
--- a/litellm/integrations/mlflow.py
+++ b/litellm/integrations/mlflow.py
@@ -125,7 +125,11 @@ class MlflowLogger(CustomLogger):
 
         # If this is the final chunk, end the span. The final chunk
         # has complete_streaming_response that gathers the full response.
-        if final_response := kwargs.get("complete_streaming_response"):
+        final_response = (
+            kwargs.get("complete_streaming_response")
+            or kwargs.get("async_complete_streaming_response")
+        )
+        if final_response:
             end_time_ns = int(end_time.timestamp() * 1e9)
 
             self._extract_and_set_chat_attributes(span, kwargs, final_response)
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 77d4fd7d5d..877b04215f 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -8,12 +8,13 @@ import os
 import re
 import subprocess
 import sys
+import threading
 import time
 import traceback
 import uuid
 from datetime import datetime as dt_object
 from functools import lru_cache
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
 
 from pydantic import BaseModel
 
@@ -1226,7 +1227,36 @@ class Logging(LiteLLMLoggingBaseClass):
         except Exception as e:
             raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
 
-    def success_handler(  # noqa: PLR0915
+    def success_handler(
+        self,
+        result=None,
+        start_time=None,
+        end_time=None,
+        cache_hit=None,
+        synchronous=None,
+        **kwargs
+    ):
+        """
+        Run the success handlers either synchronously or in a background thread via the executor.
+        If the `synchronous` argument is not provided, the global `litellm.sync_logging` config is used.
+ """ + if synchronous is None: + synchronous = litellm.sync_logging + + if synchronous: + self._success_handler(result, start_time, end_time, cache_hit, **kwargs) + else: + executor.submit( + self._success_handler, + result, + start_time, + end_time, + cache_hit, + **kwargs, + ) + + + def _success_handler( # noqa: PLR0915 self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs ): verbose_logger.debug( @@ -2376,12 +2406,7 @@ class Logging(LiteLLMLoggingBaseClass): if self._should_run_sync_callbacks_for_async_calls() is False: return - executor.submit( - self.success_handler, - result, - start_time, - end_time, - ) + self.success_handler(result, start_time, end_time) def _should_run_sync_callbacks_for_async_calls(self) -> bool: """ diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index ec20a1ad4c..560724a5a6 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1445,32 +1445,47 @@ class CustomStreamWrapper: """ Runs success logging in a thread and adds the response to the cache """ - if litellm.disable_streaming_logging is True: - """ - [NOT RECOMMENDED] - Set this via `litellm.disable_streaming_logging = True`. + def _run(): + if litellm.disable_streaming_logging is True: + """ + [NOT RECOMMENDED] + Set this via `litellm.disable_streaming_logging = True`. - Disables streaming logging. - """ - return - ## ASYNC LOGGING - # Create an event loop for the new thread - if self.logging_loop is not None: - future = asyncio.run_coroutine_threadsafe( - self.logging_obj.async_success_handler( - processed_chunk, None, None, cache_hit - ), - loop=self.logging_loop, - ) - future.result() - else: - asyncio.run( - self.logging_obj.async_success_handler( - processed_chunk, None, None, cache_hit + Disables streaming logging. 
+ """ + return + + if not litellm.sync_logging: + ## ASYNC LOGGING + # Create an event loop for the new thread + if self.logging_loop is not None: + future = asyncio.run_coroutine_threadsafe( + self.logging_obj.async_success_handler( + processed_chunk, None, None, cache_hit + ), + loop=self.logging_loop, + ) + future.result() + else: + asyncio.run( + self.logging_obj.async_success_handler( + processed_chunk, None, None, cache_hit + ) + ) + + ## SYNC LOGGING + self.logging_obj.success_handler(processed_chunk, None, None, cache_hit) + + ## Sync store in cache + if self.logging_obj._llm_caching_handler is not None: + self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache( + processed_chunk ) - ) - ## SYNC LOGGING - self.logging_obj.success_handler(processed_chunk, None, None, cache_hit) + + if litellm.sync_logging: + _run() + else: + executor.submit(_run) def finish_reason_handler(self): model_response = self.model_response_creator() @@ -1522,11 +1537,8 @@ class CustomStreamWrapper: completion_start_time=datetime.datetime.now() ) ## LOGGING - executor.submit( - self.run_success_logging_and_cache_storage, - response, - cache_hit, - ) # log response + self.run_success_logging_and_cache_storage(response, cache_hit) + choice = response.choices[0] if isinstance(choice, StreamingChoices): self.response_uptil_now += choice.delta.get("content", "") or "" @@ -1576,21 +1588,12 @@ class CustomStreamWrapper: ), cache_hit=cache_hit, ) - executor.submit( - self.logging_obj.success_handler, - complete_streaming_response.model_copy(deep=True), - None, - None, - cache_hit, - ) + logging_result = complete_streaming_response.model_copy(deep=True) else: - executor.submit( - self.logging_obj.success_handler, - response, - None, - None, - cache_hit, - ) + logging_result = response + + self.logging_obj.success_handler(logging_result, None, None, cache_hit) + if self.sent_stream_usage is False and self.send_stream_usage is True: self.sent_stream_usage = True return response @@ -1602,11 +1605,7 @@ class CustomStreamWrapper: usage = calculate_total_usage(chunks=self.chunks) processed_chunk._hidden_params["usage"] = usage ## LOGGING - executor.submit( - self.run_success_logging_and_cache_storage, - processed_chunk, - cache_hit, - ) # log response + self.run_success_logging_and_cache_storage(processed_chunk, cache_hit) return processed_chunk except Exception as e: traceback_exception = traceback.format_exc() @@ -1762,22 +1761,19 @@ class CustomStreamWrapper: self.sent_stream_usage = True return response - asyncio.create_task( - self.logging_obj.async_success_handler( - complete_streaming_response, - cache_hit=cache_hit, - start_time=None, - end_time=None, - ) - ) - executor.submit( - self.logging_obj.success_handler, - complete_streaming_response, + logging_params = dict( + result=complete_streaming_response, cache_hit=cache_hit, start_time=None, end_time=None, ) + if litellm.sync_logging: + await self.logging_obj.async_success_handler(**logging_params) + else: + asyncio.create_task(self.logging_obj.async_success_handler(**logging_params)) + + self.logging_obj.success_handler(**logging_params) raise StopAsyncIteration # Re-raise StopIteration else: diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index d4260a0300..4d2414d9a2 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -122,15 +122,9 @@ class PassThroughStreamingHandler: 
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
             )
-        threading.Thread(
-            target=litellm_logging_obj.success_handler,
-            args=(
-                standard_logging_response_object,
-                start_time,
-                end_time,
-                False,
-            ),
-        ).start()
+        litellm_logging_obj.success_handler(
+            standard_logging_response_object, start_time, end_time, False
+        )
         await litellm_logging_obj.async_success_handler(
             result=standard_logging_response_object,
             start_time=start_time,
diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py
index a111fbec09..b2d0ad2257 100644
--- a/litellm/responses/streaming_iterator.py
+++ b/litellm/responses/streaming_iterator.py
@@ -170,8 +170,7 @@ class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
                 )
             )
 
-            executor.submit(
-                self.logging_obj.success_handler,
+            self.logging_obj.success_handler(
                 result=self.completed_response,
                 cache_hit=None,
                 start_time=self.start_time,
diff --git a/litellm/utils.py b/litellm/utils.py
index 98a9c34b47..724bf674af 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -808,9 +808,7 @@ async def _client_async_logging_helper(
         f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
     )
     # check if user does not want this to be logged
-    asyncio.create_task(
-        logging_obj.async_success_handler(result, start_time, end_time)
-    )
+    await logging_obj.async_success_handler(result, start_time, end_time)
     logging_obj.handle_sync_success_callbacks_for_async_calls(
         result=result,
         start_time=start_time,
@@ -1183,12 +1181,8 @@ def client(original_function):  # noqa: PLR0915
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
             verbose_logger.info("Wrapper: Completed Call, calling success_handler")
-            executor.submit(
-                logging_obj.success_handler,
-                result,
-                start_time,
-                end_time,
-            )
+            logging_obj.success_handler(result, start_time, end_time)
+
             # RETURN RESULT
             update_response_metadata(
                 result=result,
@@ -1357,15 +1351,18 @@ def client(original_function):  # noqa: PLR0915
                 )
 
                 # LOG SUCCESS - handle streaming success logging in the _next_ object
-                asyncio.create_task(
-                    _client_async_logging_helper(
-                        logging_obj=logging_obj,
-                        result=result,
-                        start_time=start_time,
-                        end_time=end_time,
-                        is_completion_with_fallbacks=is_completion_with_fallbacks,
-                    )
+                async_logging_params = dict(
+                    logging_obj=logging_obj,
+                    result=result,
+                    start_time=start_time,
+                    end_time=end_time,
+                    is_completion_with_fallbacks=is_completion_with_fallbacks,
                 )
+                if litellm.sync_logging:
+                    await _client_async_logging_helper(**async_logging_params)
+                else:
+                    asyncio.create_task(_client_async_logging_helper(**async_logging_params))
+
                 logging_obj.handle_sync_success_callbacks_for_async_calls(
                     result=result,
                     start_time=start_time,
diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
index 445c773d99..ff64f2d6a9 100644
--- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
+++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
@@ -2,8 +2,12 @@ import json
 import os
 import sys
 from datetime import datetime
+import threading
 from unittest.mock import AsyncMock
 
+from litellm.litellm_core_utils.dd_tracing import contextmanager
+from litellm.utils import executor
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system-path
@@ -297,6 +301,7 @@ def test_dynamic_logging_global_callback():
             start_time=datetime.now(),
             end_time=datetime.now(),
             cache_hit=False,
+            synchronous=True,
         )
     except Exception as e:
         print(f"Error: {e}")
@@ -323,3 +328,72 @@ def test_get_combined_callback_list():
     assert "lago" in _logging.get_combined_callback_list(
         dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
     )
+
+
+
+@pytest.mark.parametrize("sync_logging", [True, False])
+def test_success_handler_sync_async(sync_logging):
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+    from litellm.integrations.custom_logger import CustomLogger
+    from litellm.types.utils import ModelResponse, Choices, Message
+
+    cl = CustomLogger()
+
+    litellm_logging = LiteLLMLoggingObj(
+        model="claude-3-opus-20240229",
+        messages=[{"role": "user", "content": "hi"}],
+        stream=False,
+        call_type="completion",
+        start_time=datetime.now(),
+        litellm_call_id="123",
+        function_id="456",
+    )
+
+    litellm.sync_logging = sync_logging
+
+    result = ModelResponse(
+        id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
+        created=1732306261,
+        model="claude-3-opus-20240229",
+        object="chat.completion",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="hello",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+    )
+
+
+    with (
+        patch.object(cl, "log_success_event") as mock_log_success_event,
+        patch.object(
+            executor, "submit",
+            side_effect=lambda *args, **kwargs: args[0](*args[1:], **kwargs)
+        ) as mock_executor,
+    ):
+        litellm.success_callback = [cl]
+
+        litellm_logging.success_handler(
+            result=result,
+            start_time=datetime.now(),
+            end_time=datetime.now(),
+            cache_hit=False,
+        )
+
+        if sync_logging:
+            mock_executor.assert_not_called()
+            mock_log_success_event.assert_called_once()
+            assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]
+        else:
+
+            # executor.submit is patched to run the callback inline
+            mock_executor.assert_called_once()
+            mock_log_success_event.assert_called_once()
+            assert "standard_logging_object" in mock_log_success_event.call_args.kwargs["kwargs"]
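
Usage sketch: with this patch applied, the new `litellm.sync_logging` flag could be exercised roughly as follows. `CustomLogger`, `litellm.success_callback`, and `litellm.completion` are existing litellm APIs; the `InlineLogger` class and the model name are illustrative placeholders, not part of the diff.

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class InlineLogger(CustomLogger):
        # With litellm.sync_logging = True this runs before completion()
        # returns, instead of on the background executor thread.
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            print("logged:", response_obj.id)

    litellm.success_callback = [InlineLogger()]
    litellm.sync_logging = True  # opt in to blocking (synchronous) callback execution

    response = litellm.completion(
        model="gpt-4o-mini",  # placeholder; any configured model works
        messages=[{"role": "user", "content": "hi"}],
    )
    # The success callback has already executed by this point.

Per-call behavior can still override the global default, since `success_handler` accepts an explicit `synchronous` argument that takes precedence over `litellm.sync_logging`.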