diff --git a/.circleci/config.yml b/.circleci/config.yml index 92d869568b..9acffc9625 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -994,6 +994,7 @@ jobs: - run: ruff check ./litellm # - run: python ./tests/documentation_tests/test_general_setting_keys.py - run: python ./tests/code_coverage_tests/router_code_coverage.py + - run: python ./tests/code_coverage_tests/callback_manager_test.py - run: python ./tests/code_coverage_tests/recursive_detector.py - run: python ./tests/code_coverage_tests/test_router_strategy_async.py - run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 9cd5cb8c4a..bccab529ff 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -38,6 +38,7 @@ from litellm.proxy._types import ( ) from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager import httpx import dotenv from enum import Enum @@ -50,6 +51,7 @@ if set_verbose == True: _turn_on_debug() ############################################### ### Callbacks /Logging / Success / Failure Handlers ##### +logging_callback_manager = LoggingCallbackManager() input_callback: List[Union[str, Callable, CustomLogger]] = [] success_callback: List[Union[str, Callable, CustomLogger]] = [] failure_callback: List[Union[str, Callable, CustomLogger]] = [] diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index e50e8b76d6..f7842ad48a 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -207,9 +207,9 @@ class Cache: if "cache" not in litellm.input_callback: litellm.input_callback.append("cache") if "cache" not in litellm.success_callback: - litellm.success_callback.append("cache") + litellm.logging_callback_manager.add_litellm_success_callback("cache") if "cache" not in litellm._async_success_callback: - litellm._async_success_callback.append("cache") + litellm.logging_callback_manager.add_litellm_async_success_callback("cache") self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] self.type = type self.namespace = namespace @@ -774,9 +774,9 @@ def enable_cache( if "cache" not in litellm.input_callback: litellm.input_callback.append("cache") if "cache" not in litellm.success_callback: - litellm.success_callback.append("cache") + litellm.logging_callback_manager.add_litellm_success_callback("cache") if "cache" not in litellm._async_success_callback: - litellm._async_success_callback.append("cache") + litellm.logging_callback_manager.add_litellm_async_success_callback("cache") if litellm.cache is None: litellm.cache = Cache( diff --git a/litellm/litellm_core_utils/logging_callback_manager.py b/litellm/litellm_core_utils/logging_callback_manager.py new file mode 100644 index 0000000000..860a57c5f6 --- /dev/null +++ b/litellm/litellm_core_utils/logging_callback_manager.py @@ -0,0 +1,207 @@ +from typing import Callable, List, Union + +import litellm +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger + + +class LoggingCallbackManager: + """ + A centralized class that allows easy add / remove callbacks for litellm. + + Goals of this class: + - Prevent adding duplicate callbacks / success_callback / failure_callback + - Keep a reasonable MAX_CALLBACKS limit (this ensures callbacks don't exponentially grow and consume CPU Resources) + """ + + # healthy maximum number of callbacks - unlikely someone needs more than 20 + MAX_CALLBACKS = 30 + + def add_litellm_input_callback(self, callback: Union[CustomLogger, str]): + """ + Add a input callback to litellm.input_callback + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm.input_callback + ) + + def add_litellm_service_callback( + self, callback: Union[CustomLogger, str, Callable] + ): + """ + Add a service callback to litellm.service_callback + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm.service_callback + ) + + def add_litellm_callback(self, callback: Union[CustomLogger, str, Callable]): + """ + Add a callback to litellm.callbacks + + Ensures no duplicates are added. + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm.callbacks # type: ignore + ) + + def add_litellm_success_callback( + self, callback: Union[CustomLogger, str, Callable] + ): + """ + Add a success callback to `litellm.success_callback` + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm.success_callback + ) + + def add_litellm_failure_callback( + self, callback: Union[CustomLogger, str, Callable] + ): + """ + Add a failure callback to `litellm.failure_callback` + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm.failure_callback + ) + + def add_litellm_async_success_callback( + self, callback: Union[CustomLogger, Callable, str] + ): + """ + Add a success callback to litellm._async_success_callback + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm._async_success_callback + ) + + def add_litellm_async_failure_callback( + self, callback: Union[CustomLogger, Callable, str] + ): + """ + Add a failure callback to litellm._async_failure_callback + """ + self._safe_add_callback_to_list( + callback=callback, parent_list=litellm._async_failure_callback + ) + + def _add_string_callback_to_list( + self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]] + ): + """ + Add a string callback to a list, if the callback is already in the list, do not add it again. + """ + if callback not in parent_list: + parent_list.append(callback) + else: + verbose_logger.debug( + f"Callback {callback} already exists in {parent_list}, not adding again.." + ) + + def _check_callback_list_size( + self, parent_list: List[Union[CustomLogger, Callable, str]] + ) -> bool: + """ + Check if adding another callback would exceed MAX_CALLBACKS + Returns True if safe to add, False if would exceed limit + """ + if len(parent_list) >= self.MAX_CALLBACKS: + verbose_logger.warning( + f"Cannot add callback - would exceed MAX_CALLBACKS limit of {self.MAX_CALLBACKS}. Current callbacks: {len(parent_list)}" + ) + return False + return True + + def _safe_add_callback_to_list( + self, + callback: Union[CustomLogger, Callable, str], + parent_list: List[Union[CustomLogger, Callable, str]], + ): + """ + Safe add a callback to a list, if the callback is already in the list, do not add it again. + + Ensures no duplicates are added for `str`, `Callable`, and `CustomLogger` callbacks. + """ + # Check max callbacks limit first + if not self._check_callback_list_size(parent_list): + return + + if isinstance(callback, str): + self._add_string_callback_to_list( + callback=callback, parent_list=parent_list + ) + elif isinstance(callback, CustomLogger): + self._add_custom_logger_to_list( + custom_logger=callback, + parent_list=parent_list, + ) + elif callable(callback): + self._add_callback_function_to_list( + callback=callback, parent_list=parent_list + ) + + def _add_callback_function_to_list( + self, callback: Callable, parent_list: List[Union[CustomLogger, Callable, str]] + ): + """ + Add a callback function to a list, if the callback is already in the list, do not add it again. + """ + # Check if the function already exists in the list by comparing function objects + if callback not in parent_list: + parent_list.append(callback) + else: + verbose_logger.debug( + f"Callback function {callback.__name__} already exists in {parent_list}, not adding again.." + ) + + def _add_custom_logger_to_list( + self, + custom_logger: CustomLogger, + parent_list: List[Union[CustomLogger, Callable, str]], + ): + """ + Add a custom logger to a list, if another instance of the same custom logger exists in the list, do not add it again. + """ + # Check if an instance of the same class already exists in the list + custom_logger_key = self._get_custom_logger_key(custom_logger) + custom_logger_type_name = type(custom_logger).__name__ + for existing_logger in parent_list: + if ( + isinstance(existing_logger, CustomLogger) + and self._get_custom_logger_key(existing_logger) == custom_logger_key + ): + verbose_logger.debug( + f"Custom logger of type {custom_logger_type_name}, key: {custom_logger_key} already exists in {parent_list}, not adding again.." + ) + return + parent_list.append(custom_logger) + + def _get_custom_logger_key(self, custom_logger: CustomLogger): + """ + Get a unique key for a custom logger that considers only fundamental instance variables + + Returns: + str: A unique key combining the class name and fundamental instance variables (str, bool, int) + """ + key_parts = [type(custom_logger).__name__] + + # Add only fundamental type instance variables to the key + for attr_name, attr_value in vars(custom_logger).items(): + if not attr_name.startswith("_"): # Skip private attributes + if isinstance(attr_value, (str, bool, int)): + key_parts.append(f"{attr_name}={attr_value}") + + return "-".join(key_parts) + + def _reset_all_callbacks(self): + """ + Reset all callbacks to an empty list + + Note: this is an internal function and should be used sparingly. + """ + litellm.input_callback = [] + litellm.success_callback = [] + litellm.failure_callback = [] + litellm._async_success_callback = [] + litellm._async_failure_callback = [] + litellm.callbacks = [] diff --git a/litellm/proxy/guardrails/guardrail_initializers.py b/litellm/proxy/guardrails/guardrail_initializers.py index 15d4be57cb..c32d75f986 100644 --- a/litellm/proxy/guardrails/guardrail_initializers.py +++ b/litellm/proxy/guardrails/guardrail_initializers.py @@ -13,7 +13,7 @@ def initialize_aporia(litellm_params, guardrail): event_hook=litellm_params["mode"], default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_aporia_callback) + litellm.logging_callback_manager.add_litellm_callback(_aporia_callback) def initialize_bedrock(litellm_params, guardrail): @@ -28,7 +28,7 @@ def initialize_bedrock(litellm_params, guardrail): guardrailVersion=litellm_params["guardrailVersion"], default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_bedrock_callback) + litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback) def initialize_lakera(litellm_params, guardrail): @@ -42,7 +42,7 @@ def initialize_lakera(litellm_params, guardrail): category_thresholds=litellm_params.get("category_thresholds"), default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_lakera_callback) + litellm.logging_callback_manager.add_litellm_callback(_lakera_callback) def initialize_aim(litellm_params, guardrail): @@ -55,7 +55,7 @@ def initialize_aim(litellm_params, guardrail): event_hook=litellm_params["mode"], default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_aim_callback) + litellm.logging_callback_manager.add_litellm_callback(_aim_callback) def initialize_presidio(litellm_params, guardrail): @@ -71,7 +71,7 @@ def initialize_presidio(litellm_params, guardrail): mock_redacted_text=litellm_params.get("mock_redacted_text") or None, default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_presidio_callback) + litellm.logging_callback_manager.add_litellm_callback(_presidio_callback) if litellm_params["output_parse_pii"]: _success_callback = _OPTIONAL_PresidioPIIMasking( @@ -81,7 +81,7 @@ def initialize_presidio(litellm_params, guardrail): presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"], default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_success_callback) + litellm.logging_callback_manager.add_litellm_callback(_success_callback) def initialize_hide_secrets(litellm_params, guardrail): @@ -93,7 +93,7 @@ def initialize_hide_secrets(litellm_params, guardrail): guardrail_name=guardrail["guardrail_name"], default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_secret_detection_object) + litellm.logging_callback_manager.add_litellm_callback(_secret_detection_object) def initialize_guardrails_ai(litellm_params, guardrail): @@ -111,4 +111,4 @@ def initialize_guardrails_ai(litellm_params, guardrail): guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value, default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_guardrails_ai_callback) + litellm.logging_callback_manager.add_litellm_callback(_guardrails_ai_callback) diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index 74f859a708..129be655ee 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -157,7 +157,7 @@ def init_guardrails_v2( event_hook=litellm_params["mode"], default_on=litellm_params["default_on"], ) - litellm.callbacks.append(_guardrail_callback) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore else: raise ValueError(f"Unsupported guardrail: {guardrail_type}") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 103204270e..03592f8e7e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -736,7 +736,7 @@ user_api_key_cache = DualCache( model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( dual_cache=user_api_key_cache ) -litellm.callbacks.append(model_max_budget_limiter) +litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) redis_usage_cache: Optional[RedisCache] = ( None # redis cache used for tracking spend, tpm/rpm limits ) @@ -934,7 +934,7 @@ def cost_tracking(): if isinstance(litellm._async_success_callback, list): verbose_proxy_logger.debug("setting litellm success callback to track cost") if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore - litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore + litellm.logging_callback_manager.add_litellm_async_success_callback(_PROXY_track_cost_callback) # type: ignore def error_tracking(): @@ -943,7 +943,7 @@ def error_tracking(): if isinstance(litellm.failure_callback, list): verbose_proxy_logger.debug("setting litellm failure callback to track cost") if (_PROXY_failure_handler) not in litellm.failure_callback: # type: ignore - litellm.failure_callback.append(_PROXY_failure_handler) # type: ignore + litellm.logging_callback_manager.add_litellm_failure_callback(_PROXY_failure_handler) # type: ignore def _set_spend_logs_payload( @@ -1890,12 +1890,14 @@ class ProxyConfig: for callback in value: # user passed custom_callbacks.async_on_succes_logger. They need us to import a function if "." in callback: - litellm.success_callback.append( + litellm.logging_callback_manager.add_litellm_success_callback( get_instance_fn(value=callback) ) # these are litellm callbacks - "langfuse", "sentry", "wandb" else: - litellm.success_callback.append(callback) + litellm.logging_callback_manager.add_litellm_success_callback( + callback + ) if "prometheus" in callback: if not premium_user: raise Exception( @@ -1919,12 +1921,14 @@ class ProxyConfig: for callback in value: # user passed custom_callbacks.async_on_succes_logger. They need us to import a function if "." in callback: - litellm.failure_callback.append( + litellm.logging_callback_manager.add_litellm_failure_callback( get_instance_fn(value=callback) ) # these are litellm callbacks - "langfuse", "sentry", "wandb" else: - litellm.failure_callback.append(callback) + litellm.logging_callback_manager.add_litellm_failure_callback( + callback + ) print( # noqa f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}" ) # noqa @@ -2215,7 +2219,7 @@ class ProxyConfig: }, ) if _logger is not None: - litellm.callbacks.append(_logger) + litellm.logging_callback_manager.add_litellm_callback(_logger) pass def initialize_secret_manager(self, key_management_system: Optional[str]): @@ -2497,7 +2501,9 @@ class ProxyConfig: success_callback, "success" ) elif success_callback not in litellm.success_callback: - litellm.success_callback.append(success_callback) + litellm.logging_callback_manager.add_litellm_success_callback( + success_callback + ) # Add failure callbacks from DB to litellm if failure_callbacks is not None and isinstance(failure_callbacks, list): @@ -2510,7 +2516,9 @@ class ProxyConfig: failure_callback, "failure" ) elif failure_callback not in litellm.failure_callback: - litellm.failure_callback.append(failure_callback) + litellm.logging_callback_manager.add_litellm_failure_callback( + failure_callback + ) def _add_environment_variables_from_db_config(self, config_data: dict) -> None: """ diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 409f4bfa5c..934b3da545 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -323,8 +323,8 @@ class ProxyLogging: # NOTE: ENSURE we only add callbacks when alerting is on # We should NOT add callbacks when alerting is off if "daily_reports" in self.alert_types: - litellm.callbacks.append(self.slack_alerting_instance) # type: ignore - litellm.success_callback.append( + litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance) # type: ignore + litellm.logging_callback_manager.add_litellm_success_callback( self.slack_alerting_instance.response_taking_too_long_callback ) @@ -332,10 +332,10 @@ class ProxyLogging: self.internal_usage_cache.dual_cache.redis_cache = redis_cache def _init_litellm_callbacks(self, llm_router: Optional[Router] = None): - litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore - litellm.callbacks.append(self.max_budget_limiter) # type: ignore - litellm.callbacks.append(self.cache_control_check) # type: ignore - litellm.callbacks.append(self.service_logging_obj) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore for callback in litellm.callbacks: if isinstance(callback, str): callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore @@ -348,13 +348,13 @@ class ProxyLogging: if callback not in litellm.input_callback: litellm.input_callback.append(callback) # type: ignore if callback not in litellm.success_callback: - litellm.success_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore if callback not in litellm.failure_callback: - litellm.failure_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore if callback not in litellm._async_success_callback: - litellm._async_success_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore if callback not in litellm._async_failure_callback: - litellm._async_failure_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore if callback not in litellm.service_callback: litellm.service_callback.append(callback) # type: ignore diff --git a/litellm/router.py b/litellm/router.py index 63669ef588..fb3250367b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -483,15 +483,21 @@ class Router: self.access_groups = None ## USAGE TRACKING ## if isinstance(litellm._async_success_callback, list): - litellm._async_success_callback.append(self.deployment_callback_on_success) + litellm.logging_callback_manager.add_litellm_async_success_callback( + self.deployment_callback_on_success + ) else: - litellm._async_success_callback.append(self.deployment_callback_on_success) + litellm.logging_callback_manager.add_litellm_async_success_callback( + self.deployment_callback_on_success + ) if isinstance(litellm.success_callback, list): - litellm.success_callback.append(self.sync_deployment_callback_on_success) + litellm.logging_callback_manager.add_litellm_success_callback( + self.sync_deployment_callback_on_success + ) else: litellm.success_callback = [self.sync_deployment_callback_on_success] if isinstance(litellm._async_failure_callback, list): - litellm._async_failure_callback.append( + litellm.logging_callback_manager.add_litellm_async_failure_callback( self.async_deployment_callback_on_failure ) else: @@ -500,7 +506,9 @@ class Router: ] ## COOLDOWNS ## if isinstance(litellm.failure_callback, list): - litellm.failure_callback.append(self.deployment_callback_on_failure) + litellm.logging_callback_manager.add_litellm_failure_callback( + self.deployment_callback_on_failure + ) else: litellm.failure_callback = [self.deployment_callback_on_failure] verbose_router_logger.debug( @@ -606,7 +614,7 @@ class Router: model_list=self.model_list, ) if _callback is not None: - litellm.callbacks.append(_callback) + litellm.logging_callback_manager.add_litellm_callback(_callback) def routing_strategy_init( self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict @@ -625,7 +633,7 @@ class Router: else: litellm.input_callback = [self.leastbusy_logger] # type: ignore if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self.leastbusy_logger) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.leastbusy_logger) # type: ignore elif ( routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING @@ -636,7 +644,7 @@ class Router: routing_args=routing_strategy_args, ) if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self.lowesttpm_logger) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger) # type: ignore elif ( routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2 @@ -647,7 +655,7 @@ class Router: routing_args=routing_strategy_args, ) if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger_v2) # type: ignore elif ( routing_strategy == RoutingStrategy.LATENCY_BASED.value or routing_strategy == RoutingStrategy.LATENCY_BASED @@ -658,7 +666,7 @@ class Router: routing_args=routing_strategy_args, ) if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self.lowestlatency_logger) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.lowestlatency_logger) # type: ignore elif ( routing_strategy == RoutingStrategy.COST_BASED.value or routing_strategy == RoutingStrategy.COST_BASED @@ -669,7 +677,7 @@ class Router: routing_args={}, ) if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self.lowestcost_logger) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self.lowestcost_logger) # type: ignore else: pass @@ -5835,8 +5843,8 @@ class Router: self.slack_alerting_logger = _slack_alerting_logger - litellm.callbacks.append(_slack_alerting_logger) # type: ignore - litellm.success_callback.append( + litellm.logging_callback_manager.add_litellm_callback(_slack_alerting_logger) # type: ignore + litellm.logging_callback_manager.add_litellm_success_callback( _slack_alerting_logger.response_taking_too_long_callback ) verbose_router_logger.info( diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index 920f6c0881..d584923ed5 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -64,7 +64,7 @@ class RouterBudgetLimiting(CustomLogger): # Add self to litellm callbacks if it's a list if isinstance(litellm.callbacks, list): - litellm.callbacks.append(self) # type: ignore + litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore async def async_filter_deployments( self, diff --git a/litellm/utils.py b/litellm/utils.py index 48f4ea7ca2..b1e683113a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -352,8 +352,12 @@ def _add_custom_logger_callback_to_specific_event( and _custom_logger_class_exists_in_success_callbacks(callback_class) is False ): - litellm.success_callback.append(callback_class) - litellm._async_success_callback.append(callback_class) + litellm.logging_callback_manager.add_litellm_success_callback( + callback_class + ) + litellm.logging_callback_manager.add_litellm_async_success_callback( + callback_class + ) if callback in litellm.success_callback: litellm.success_callback.remove( callback @@ -367,8 +371,12 @@ def _add_custom_logger_callback_to_specific_event( and _custom_logger_class_exists_in_failure_callbacks(callback_class) is False ): - litellm.failure_callback.append(callback_class) - litellm._async_failure_callback.append(callback_class) + litellm.logging_callback_manager.add_litellm_failure_callback( + callback_class + ) + litellm.logging_callback_manager.add_litellm_async_failure_callback( + callback_class + ) if callback in litellm.failure_callback: litellm.failure_callback.remove( callback @@ -447,13 +455,13 @@ def function_setup( # noqa: PLR0915 if callback not in litellm.input_callback: litellm.input_callback.append(callback) # type: ignore if callback not in litellm.success_callback: - litellm.success_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore if callback not in litellm.failure_callback: - litellm.failure_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore if callback not in litellm._async_success_callback: - litellm._async_success_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore if callback not in litellm._async_failure_callback: - litellm._async_failure_callback.append(callback) # type: ignore + litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore print_verbose( f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" ) @@ -488,12 +496,16 @@ def function_setup( # noqa: PLR0915 removed_async_items = [] for index, callback in enumerate(litellm.success_callback): # type: ignore if inspect.iscoroutinefunction(callback): - litellm._async_success_callback.append(callback) + litellm.logging_callback_manager.add_litellm_async_success_callback( + callback + ) removed_async_items.append(index) elif callback == "dynamodb" or callback == "openmeter": # dynamo is an async callback, it's used for the proxy and needs to be async # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy - litellm._async_success_callback.append(callback) + litellm.logging_callback_manager.add_litellm_async_success_callback( + callback + ) removed_async_items.append(index) elif ( callback in litellm._known_custom_logger_compatible_callbacks @@ -509,7 +521,9 @@ def function_setup( # noqa: PLR0915 removed_async_items = [] for index, callback in enumerate(litellm.failure_callback): # type: ignore if inspect.iscoroutinefunction(callback): - litellm._async_failure_callback.append(callback) + litellm.logging_callback_manager.add_litellm_async_failure_callback( + callback + ) removed_async_items.append(index) elif ( callback in litellm._known_custom_logger_compatible_callbacks diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py index 994263f1b8..8465df78d8 100644 --- a/tests/batches_tests/test_openai_batches_and_files.py +++ b/tests/batches_tests/test_openai_batches_and_files.py @@ -179,8 +179,9 @@ async def test_async_create_batch(provider): 2. Create Batch Request 3. Retrieve the specific batch """ + litellm._turn_on_debug() print("Testing async create batch") - + litellm.logging_callback_manager._reset_all_callbacks() custom_logger = TestCustomLogger() litellm.callbacks = [custom_logger, "datadog"] diff --git a/tests/code_coverage_tests/callback_manager_test.py b/tests/code_coverage_tests/callback_manager_test.py new file mode 100644 index 0000000000..557a5027c4 --- /dev/null +++ b/tests/code_coverage_tests/callback_manager_test.py @@ -0,0 +1,113 @@ +import ast +import os + +ALLOWED_FILES = [ + # local files + "../../litellm/litellm_core_utils/litellm.logging_callback_manager.py", + "../../litellm/proxy/common_utils/callback_utils.py", + # when running on ci/cd + "./litellm/litellm_core_utils/litellm.logging_callback_manager.py", + "./litellm/proxy/common_utils/callback_utils.py", +] + +warning_msg = "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager" + + +def check_for_callback_modifications(file_path): + """ + Checks if any direct modifications to specific litellm callback lists are made in the given file. + Also prints the violating line of code. + """ + print("..checking file=", file_path) + if file_path in ALLOWED_FILES: + return [] + + violations = [] + with open(file_path, "r") as file: + try: + lines = file.readlines() + tree = ast.parse("".join(lines)) + except SyntaxError: + print(f"Warning: Syntax error in file {file_path}") + return violations + + protected_lists = [ + "callbacks", + "success_callback", + "failure_callback", + "_async_success_callback", + "_async_failure_callback", + ] + + forbidden_operations = ["append", "extend", "insert"] + + for node in ast.walk(tree): + # Check for attribute calls like litellm.callbacks.append() + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + # Get the full attribute chain + attr_chain = [] + current = node.func + while isinstance(current, ast.Attribute): + attr_chain.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + attr_chain.append(current.id) + + # Reverse to get the chain from root to leaf + attr_chain = attr_chain[::-1] + + # Check if the attribute chain starts with 'litellm' and modifies a protected list + if ( + len(attr_chain) >= 3 + and attr_chain[0] == "litellm" + and attr_chain[2] in forbidden_operations + ): + protected_list = attr_chain[1] + operation = attr_chain[2] + if ( + protected_list in protected_lists + and operation in forbidden_operations + ): + violating_line = lines[node.lineno - 1].strip() + violations.append( + f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. " + f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. " + f"Please use LoggingCallbackManager instead. {warning_msg}" + ) + + return violations + + +def scan_directory_for_callback_modifications(base_dir): + """ + Scans all Python files in the directory tree for unauthorized callback list modifications. + """ + all_violations = [] + for root, _, files in os.walk(base_dir): + for file in files: + if file.endswith(".py"): + file_path = os.path.join(root, file) + violations = check_for_callback_modifications(file_path) + all_violations.extend(violations) + return all_violations + + +def test_no_unauthorized_callback_modifications(): + """ + Test to ensure callback lists are not modified directly anywhere in the codebase. + """ + base_dir = "./litellm" # Adjust this path as needed + # base_dir = "../../litellm" # LOCAL TESTING + + violations = scan_directory_for_callback_modifications(base_dir) + if violations: + print(f"\nFound {len(violations)} callback modification violations:") + for violation in violations: + print("\n" + violation) + raise AssertionError( + "Found unauthorized callback modifications. See above for details." + ) + + +if __name__ == "__main__": + test_no_unauthorized_callback_modifications() diff --git a/tests/image_gen_tests/base_image_generation_test.py b/tests/image_gen_tests/base_image_generation_test.py index e0652114db..746b0ef713 100644 --- a/tests/image_gen_tests/base_image_generation_test.py +++ b/tests/image_gen_tests/base_image_generation_test.py @@ -48,6 +48,7 @@ class BaseImageGenTest(ABC): """Test basic image generation""" try: custom_logger = TestCustomLogger() + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [custom_logger] base_image_generation_call_args = self.get_base_image_generation_call_args() litellm.set_verbose = True diff --git a/tests/litellm_utils_tests/test_logging_callback_manager.py b/tests/litellm_utils_tests/test_logging_callback_manager.py new file mode 100644 index 0000000000..71ffb18678 --- /dev/null +++ b/tests/litellm_utils_tests/test_logging_callback_manager.py @@ -0,0 +1,179 @@ +import json +import os +import sys +import time +from datetime import datetime +from unittest.mock import AsyncMock, patch, MagicMock +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager +from litellm.integrations.langfuse.langfuse_prompt_management import ( + LangfusePromptManagement, +) +from litellm.integrations.opentelemetry import OpenTelemetry + + +# Test fixtures +@pytest.fixture +def callback_manager(): + manager = LoggingCallbackManager() + # Reset callbacks before each test + manager._reset_all_callbacks() + return manager + + +@pytest.fixture +def mock_custom_logger(): + class TestLogger(CustomLogger): + def log_success_event(self, kwargs, response_obj, start_time, end_time): + pass + + return TestLogger() + + +# Test cases +def test_add_string_callback(): + """ + Test adding a string callback to litellm.callbacks - only 1 instance of the string callback should be added + """ + manager = LoggingCallbackManager() + test_callback = "test_callback" + + # Add string callback + manager.add_litellm_callback(test_callback) + assert test_callback in litellm.callbacks + + # Test duplicate prevention + manager.add_litellm_callback(test_callback) + assert litellm.callbacks.count(test_callback) == 1 + + +def test_duplicate_langfuse_logger_test(): + manager = LoggingCallbackManager() + for _ in range(10): + langfuse_logger = LangfusePromptManagement() + manager.add_litellm_success_callback(langfuse_logger) + print("litellm.success_callback: ", litellm.success_callback) + assert len(litellm.success_callback) == 1 + + +def test_duplicate_multiple_loggers_test(): + manager = LoggingCallbackManager() + for _ in range(10): + langfuse_logger = LangfusePromptManagement() + otel_logger = OpenTelemetry() + manager.add_litellm_success_callback(langfuse_logger) + manager.add_litellm_success_callback(otel_logger) + print("litellm.success_callback: ", litellm.success_callback) + assert len(litellm.success_callback) == 2 + + # Check exactly one instance of each logger type + langfuse_count = sum( + 1 + for callback in litellm.success_callback + if isinstance(callback, LangfusePromptManagement) + ) + otel_count = sum( + 1 + for callback in litellm.success_callback + if isinstance(callback, OpenTelemetry) + ) + + assert ( + langfuse_count == 1 + ), "Should have exactly one LangfusePromptManagement instance" + assert otel_count == 1, "Should have exactly one OpenTelemetry instance" + + +def test_add_function_callback(): + manager = LoggingCallbackManager() + + def test_func(kwargs): + pass + + # Add function callback + manager.add_litellm_callback(test_func) + assert test_func in litellm.callbacks + + # Test duplicate prevention + manager.add_litellm_callback(test_func) + assert litellm.callbacks.count(test_func) == 1 + + +def test_add_custom_logger(mock_custom_logger): + manager = LoggingCallbackManager() + + # Add custom logger + manager.add_litellm_callback(mock_custom_logger) + assert mock_custom_logger in litellm.callbacks + + +def test_add_multiple_callback_types(mock_custom_logger): + manager = LoggingCallbackManager() + + def test_func(kwargs): + pass + + string_callback = "test_callback" + + # Add different types of callbacks + manager.add_litellm_callback(string_callback) + manager.add_litellm_callback(test_func) + manager.add_litellm_callback(mock_custom_logger) + + assert string_callback in litellm.callbacks + assert test_func in litellm.callbacks + assert mock_custom_logger in litellm.callbacks + assert len(litellm.callbacks) == 3 + + +def test_success_failure_callbacks(): + manager = LoggingCallbackManager() + + success_callback = "success_callback" + failure_callback = "failure_callback" + + # Add callbacks + manager.add_litellm_success_callback(success_callback) + manager.add_litellm_failure_callback(failure_callback) + + assert success_callback in litellm.success_callback + assert failure_callback in litellm.failure_callback + + +def test_async_callbacks(): + manager = LoggingCallbackManager() + + async_success = "async_success" + async_failure = "async_failure" + + # Add async callbacks + manager.add_litellm_async_success_callback(async_success) + manager.add_litellm_async_failure_callback(async_failure) + + assert async_success in litellm._async_success_callback + assert async_failure in litellm._async_failure_callback + + +def test_reset_callbacks(callback_manager): + # Add various callbacks + callback_manager.add_litellm_callback("test") + callback_manager.add_litellm_success_callback("success") + callback_manager.add_litellm_failure_callback("failure") + callback_manager.add_litellm_async_success_callback("async_success") + callback_manager.add_litellm_async_failure_callback("async_failure") + + # Reset all callbacks + callback_manager._reset_all_callbacks() + + # Verify all callback lists are empty + assert len(litellm.callbacks) == 0 + assert len(litellm.success_callback) == 0 + assert len(litellm.failure_callback) == 0 + assert len(litellm._async_success_callback) == 0 + assert len(litellm._async_failure_callback) == 0 diff --git a/tests/local_testing/test_custom_callback_input.py b/tests/local_testing/test_custom_callback_input.py index 911defd0b4..9630896a52 100644 --- a/tests/local_testing/test_custom_callback_input.py +++ b/tests/local_testing/test_custom_callback_input.py @@ -846,6 +846,7 @@ async def test_async_embedding_openai(): assert len(customHandler_success.errors) == 0 assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] try: response = await litellm.aembedding( @@ -882,6 +883,7 @@ def test_amazing_sync_embedding(): assert len(customHandler_success.errors) == 0 assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] try: response = litellm.embedding( @@ -916,6 +918,7 @@ async def test_async_embedding_azure(): assert len(customHandler_success.errors) == 0 assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] try: response = await litellm.aembedding( @@ -956,6 +959,7 @@ async def test_async_embedding_bedrock(): assert len(customHandler_success.errors) == 0 assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] try: response = await litellm.aembedding( @@ -1123,6 +1127,7 @@ def test_image_generation_openai(): assert len(customHandler_success.errors) == 0 assert len(customHandler_success.states) == 3 # pre, post, success # test failure callback + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] try: response = litellm.image_generation( diff --git a/tests/local_testing/test_custom_callback_router.py b/tests/local_testing/test_custom_callback_router.py index 3e9ac39eda..2234690101 100644 --- a/tests/local_testing/test_custom_callback_router.py +++ b/tests/local_testing/test_custom_callback_router.py @@ -415,6 +415,8 @@ async def test_async_chat_azure(): len(customHandler_completion_azure_router.states) == 3 ) # pre, post, success # streaming + + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_streaming_azure_router] router2 = Router(model_list=model_list, num_retries=0) # type: ignore response = await router2.acompletion( @@ -445,6 +447,8 @@ async def test_async_chat_azure(): "rpm": 1800, }, ] + + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] router3 = Router(model_list=model_list, num_retries=0) # type: ignore try: @@ -507,6 +511,7 @@ async def test_async_embedding_azure(): "rpm": 1800, }, ] + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] router3 = Router(model_list=model_list, num_retries=0) # type: ignore try: diff --git a/tests/local_testing/test_custom_logger.py b/tests/local_testing/test_custom_logger.py index 4058a9aa03..d9eb50eb73 100644 --- a/tests/local_testing/test_custom_logger.py +++ b/tests/local_testing/test_custom_logger.py @@ -261,6 +261,7 @@ def test_azure_completion_stream(): @pytest.mark.asyncio async def test_async_custom_handler_completion(): try: + litellm._turn_on_debug customHandler_success = MyCustomHandler() customHandler_failure = MyCustomHandler() # success @@ -284,6 +285,7 @@ async def test_async_custom_handler_completion(): == "gpt-3.5-turbo" ) # failure + litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [customHandler_failure] messages = [ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py index be662d0c1b..add9deb6d7 100644 --- a/tests/test_ratelimit.py +++ b/tests/test_ratelimit.py @@ -13,6 +13,7 @@ sys.path.insert( 0, os.path.abspath("../") ) # Adds the parent directory to the system path +import litellm from pydantic import BaseModel from litellm import utils, Router @@ -124,6 +125,7 @@ def test_rate_limit( ExpectNoException: Signfies that no other error has happened. A NOP """ # Can send more messages then we're going to; so don't expect a rate limit error + litellm.logging_callback_manager._reset_all_callbacks() args = locals() print(f"args: {args}") expected_exception = (