From 541326731fb930889b0afed2d7541d9a7efe7e46 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Nov 2024 02:00:45 +0530 Subject: [PATCH] fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically --- litellm/litellm_core_utils/litellm_logging.py | 82 ++++++------------- .../test_unit_tests_init_callbacks.py | 72 ++++++++++++++++ 2 files changed, 96 insertions(+), 58 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index a9ded458d..298e28974 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -934,19 +934,10 @@ class Logging: status="success", ) ) - if self.dynamic_success_callbacks is not None and isinstance( - self.dynamic_success_callbacks, list - ): - callbacks = self.dynamic_success_callbacks - ## keep the internal functions ## - for callback in litellm.success_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) ## REDACT MESSAGES ## result = redact_message_input_output_from_logging( @@ -1469,21 +1460,10 @@ class Logging: status="success", ) ) - if self.dynamic_async_success_callbacks is not None and isinstance( - self.dynamic_async_success_callbacks, list - ): - callbacks = self.dynamic_async_success_callbacks - ## keep the internal functions ## - for callback in litellm._async_success_callback: - callback_name = "" - if isinstance(callback, CustomLogger): - callback_name = callback.__class__.__name__ - if callable(callback): - callback_name = callback.__name__ - if "_PROXY_" in callback_name: - callbacks.append(callback) - else: - callbacks = litellm._async_success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_success_callbacks, + global_callbacks=litellm._async_success_callback, + ) result = redact_message_input_output_from_logging( model_call_details=( @@ -1750,21 +1730,10 @@ class Logging: start_time=start_time, end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_failure_callbacks is not None and isinstance( - self.dynamic_failure_callbacks, list - ): - callbacks = self.dynamic_failure_callbacks - ## keep the internal functions ## - for callback in litellm.failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_failure_callbacks, + global_callbacks=litellm.failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created @@ -1947,21 +1916,10 @@ class Logging: end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_async_failure_callbacks is not None and isinstance( - self.dynamic_async_failure_callbacks, list - ): - callbacks = self.dynamic_async_failure_callbacks - ## keep the internal functions ## - for callback in litellm._async_failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm._async_failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_failure_callbacks, + global_callbacks=litellm._async_failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created for callback in callbacks: @@ -2953,3 +2911,11 @@ def modify_integration(integration_name, integration_params): if integration_name == "supabase": if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] + + +def get_combined_callback_list( + dynamic_success_callbacks: Optional[List], global_callbacks: List +) -> List: + if dynamic_success_callbacks is None: + return global_callbacks + return list(set(dynamic_success_callbacks + global_callbacks)) diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 38883fa38..2c373772a 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -216,3 +216,75 @@ async def test_init_custom_logger_compatible_class_as_callback(): await use_callback_in_llm_call(callback, used_in="success_callback") reset_env_vars() + + +def test_dynamic_logging_global_callback(): + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.integrations.custom_logger import CustomLogger + from litellm.types.utils import ModelResponse, Choices, Message, Usage + + cl = CustomLogger() + + litellm_logging = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + kwargs={ + "langfuse_public_key": "my-mock-public-key", + "langfuse_secret_key": "my-mock-secret-key", + }, + dynamic_success_callbacks=["langfuse"], + ) + + with patch.object(cl, "log_success_event") as mock_log_success_event: + cl.log_success_event = mock_log_success_event + litellm.success_callback = [cl] + + try: + litellm_logging.success_handler( + result=ModelResponse( + id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", + created=1732306261, + model="claude-3-opus-20240229", + object="chat.completion", + system_fingerprint=None, + choices=[ + Choices( + finish_reason="stop", + index=0, + message=Message( + content="hello", + role="assistant", + tool_calls=None, + function_call=None, + ), + ) + ], + usage=Usage( + completion_tokens=20, + prompt_tokens=10, + total_tokens=30, + completion_tokens_details=None, + prompt_tokens_details=None, + ), + ), + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + ) + except Exception as e: + print(f"Error: {e}") + + mock_log_success_event.assert_called_once() + + +def test_get_combined_callback_list(): + from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list + + assert get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) == ["langfuse", "lago"]