(Refactor / QA) - Use LoggingCallbackManager to append callbacks and ensure no duplicate callbacks are added (#8112)

* LoggingCallbackManager

* add logging_callback_manager

* use logging_callback_manager

* add add_litellm_failure_callback

* use add_litellm_callback

* use add_litellm_async_success_callback

* add_litellm_async_failure_callback

* linting fix

* fix logging callback manager

* test_duplicate_multiple_loggers_test

* use _reset_all_callbacks

* fix testing with dup callbacks

* test_basic_image_generation

* reset callbacks for tests

* fix check for _add_custom_logger_to_list

* fix test_amazing_sync_embedding

* fix _get_custom_logger_key

* fix batches testing

* fix _reset_all_callbacks

* fix _check_callback_list_size

* add callback_manager_test

* fix test gemini-2.0-flash-thinking-exp-01-21
This commit is contained in:
Ishaan Jaff 2025-01-30 19:35:50 -08:00 committed by GitHub
parent 3eac1634fa
commit 8a235e7d38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 607 additions and 59 deletions

View file

@ -994,6 +994,7 @@ jobs:
- run: ruff check ./litellm
# - run: python ./tests/documentation_tests/test_general_setting_keys.py
- run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/code_coverage_tests/callback_manager_test.py
- run: python ./tests/code_coverage_tests/recursive_detector.py
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py

View file

@ -38,6 +38,7 @@ from litellm.proxy._types import (
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
import httpx
import dotenv
from enum import Enum
@ -50,6 +51,7 @@ if set_verbose == True:
_turn_on_debug()
###############################################
### Callbacks /Logging / Success / Failure Handlers #####
logging_callback_manager = LoggingCallbackManager()
input_callback: List[Union[str, Callable, CustomLogger]] = []
success_callback: List[Union[str, Callable, CustomLogger]] = []
failure_callback: List[Union[str, Callable, CustomLogger]] = []

View file

@ -207,9 +207,9 @@ class Cache:
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.success_callback.append("cache")
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache")
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
self.type = type
self.namespace = namespace
@ -774,9 +774,9 @@ def enable_cache(
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.success_callback.append("cache")
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache")
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
if litellm.cache is None:
litellm.cache = Cache(

View file

@ -0,0 +1,207 @@
from typing import Callable, List, Union
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
class LoggingCallbackManager:
    """
    A centralized class that allows easy add / remove callbacks for litellm.

    Goals of this class:
    - Prevent adding duplicate callbacks / success_callback / failure_callback
    - Keep a reasonable MAX_CALLBACKS limit (this ensures callbacks don't exponentially grow and consume CPU Resources)
    """

    # healthy maximum number of callbacks - unlikely someone needs more than 20
    MAX_CALLBACKS = 30

    def add_litellm_input_callback(self, callback: Union[CustomLogger, str]):
        """
        Add a input callback to litellm.input_callback
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm.input_callback
        )

    def add_litellm_service_callback(
        self, callback: Union[CustomLogger, str, Callable]
    ):
        """
        Add a service callback to litellm.service_callback
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm.service_callback
        )

    def add_litellm_callback(self, callback: Union[CustomLogger, str, Callable]):
        """
        Add a callback to litellm.callbacks

        Ensures no duplicates are added.
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm.callbacks  # type: ignore
        )

    def add_litellm_success_callback(
        self, callback: Union[CustomLogger, str, Callable]
    ):
        """
        Add a success callback to `litellm.success_callback`
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm.success_callback
        )

    def add_litellm_failure_callback(
        self, callback: Union[CustomLogger, str, Callable]
    ):
        """
        Add a failure callback to `litellm.failure_callback`
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm.failure_callback
        )

    def add_litellm_async_success_callback(
        self, callback: Union[CustomLogger, Callable, str]
    ):
        """
        Add a success callback to litellm._async_success_callback
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm._async_success_callback
        )

    def add_litellm_async_failure_callback(
        self, callback: Union[CustomLogger, Callable, str]
    ):
        """
        Add a failure callback to litellm._async_failure_callback
        """
        self._safe_add_callback_to_list(
            callback=callback, parent_list=litellm._async_failure_callback
        )

    def _add_string_callback_to_list(
        self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
    ):
        """
        Add a string callback to a list, if the callback is already in the list, do not add it again.
        """
        if callback in parent_list:
            verbose_logger.debug(
                f"Callback {callback} already exists in {parent_list}, not adding again.."
            )
            return
        # size check AFTER duplicate detection: a dedupe no-op must not warn
        if self._check_callback_list_size(parent_list):
            parent_list.append(callback)

    def _check_callback_list_size(
        self, parent_list: List[Union[CustomLogger, Callable, str]]
    ) -> bool:
        """
        Check if adding another callback would exceed MAX_CALLBACKS

        Returns True if safe to add, False if would exceed limit
        """
        if len(parent_list) >= self.MAX_CALLBACKS:
            verbose_logger.warning(
                f"Cannot add callback - would exceed MAX_CALLBACKS limit of {self.MAX_CALLBACKS}. Current callbacks: {len(parent_list)}"
            )
            return False
        return True

    def _safe_add_callback_to_list(
        self,
        callback: Union[CustomLogger, Callable, str],
        parent_list: List[Union[CustomLogger, Callable, str]],
    ):
        """
        Safe add a callback to a list, if the callback is already in the list, do not add it again.

        Ensures no duplicates are added for `str`, `Callable`, and `CustomLogger` callbacks.

        The MAX_CALLBACKS size check is performed inside each per-type helper,
        AFTER duplicate detection, so that re-adding an existing callback to a
        full list is treated as the no-op it is instead of logging a spurious
        limit warning.
        """
        if isinstance(callback, str):
            self._add_string_callback_to_list(
                callback=callback, parent_list=parent_list
            )
        elif isinstance(callback, CustomLogger):
            self._add_custom_logger_to_list(
                custom_logger=callback,
                parent_list=parent_list,
            )
        elif callable(callback):
            self._add_callback_function_to_list(
                callback=callback, parent_list=parent_list
            )

    def _add_callback_function_to_list(
        self, callback: Callable, parent_list: List[Union[CustomLogger, Callable, str]]
    ):
        """
        Add a callback function to a list, if the callback is already in the list, do not add it again.
        """
        # Check if the function already exists in the list by comparing function objects
        if callback in parent_list:
            # getattr: functools.partial objects / arbitrary callables may not
            # define __name__, which would raise AttributeError here
            callback_name = getattr(callback, "__name__", repr(callback))
            verbose_logger.debug(
                f"Callback function {callback_name} already exists in {parent_list}, not adding again.."
            )
            return
        if self._check_callback_list_size(parent_list):
            parent_list.append(callback)

    def _add_custom_logger_to_list(
        self,
        custom_logger: CustomLogger,
        parent_list: List[Union[CustomLogger, Callable, str]],
    ):
        """
        Add a custom logger to a list, if another instance of the same custom logger exists in the list, do not add it again.
        """
        # Two CustomLogger instances are considered duplicates when their class
        # name plus fundamental (str/bool/int) attributes match, see
        # _get_custom_logger_key
        custom_logger_key = self._get_custom_logger_key(custom_logger)
        custom_logger_type_name = type(custom_logger).__name__
        for existing_logger in parent_list:
            if (
                isinstance(existing_logger, CustomLogger)
                and self._get_custom_logger_key(existing_logger) == custom_logger_key
            ):
                verbose_logger.debug(
                    f"Custom logger of type {custom_logger_type_name}, key: {custom_logger_key} already exists in {parent_list}, not adding again.."
                )
                return
        if self._check_callback_list_size(parent_list):
            parent_list.append(custom_logger)

    def _get_custom_logger_key(self, custom_logger: CustomLogger) -> str:
        """
        Get a unique key for a custom logger that considers only fundamental instance variables

        Returns:
            str: A unique key combining the class name and fundamental instance variables (str, bool, int)
        """
        key_parts = [type(custom_logger).__name__]
        # Add only fundamental type instance variables to the key
        # (bool is a subclass of int, so it matches either way)
        for attr_name, attr_value in vars(custom_logger).items():
            if not attr_name.startswith("_"):  # Skip private attributes
                if isinstance(attr_value, (str, bool, int)):
                    key_parts.append(f"{attr_name}={attr_value}")
        return "-".join(key_parts)

    def _reset_all_callbacks(self):
        """
        Reset all callbacks to an empty list

        Note: this is an internal function and should be used sparingly.
        """
        # NOTE(review): litellm.service_callback is NOT reset here even though
        # add_litellm_service_callback can populate it - confirm whether that
        # omission is intentional.
        litellm.input_callback = []
        litellm.success_callback = []
        litellm.failure_callback = []
        litellm._async_success_callback = []
        litellm._async_failure_callback = []
        litellm.callbacks = []

View file

@ -13,7 +13,7 @@ def initialize_aporia(litellm_params, guardrail):
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_aporia_callback)
litellm.logging_callback_manager.add_litellm_callback(_aporia_callback)
def initialize_bedrock(litellm_params, guardrail):
@ -28,7 +28,7 @@ def initialize_bedrock(litellm_params, guardrail):
guardrailVersion=litellm_params["guardrailVersion"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_bedrock_callback)
litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback)
def initialize_lakera(litellm_params, guardrail):
@ -42,7 +42,7 @@ def initialize_lakera(litellm_params, guardrail):
category_thresholds=litellm_params.get("category_thresholds"),
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_lakera_callback)
litellm.logging_callback_manager.add_litellm_callback(_lakera_callback)
def initialize_aim(litellm_params, guardrail):
@ -55,7 +55,7 @@ def initialize_aim(litellm_params, guardrail):
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_aim_callback)
litellm.logging_callback_manager.add_litellm_callback(_aim_callback)
def initialize_presidio(litellm_params, guardrail):
@ -71,7 +71,7 @@ def initialize_presidio(litellm_params, guardrail):
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_presidio_callback)
litellm.logging_callback_manager.add_litellm_callback(_presidio_callback)
if litellm_params["output_parse_pii"]:
_success_callback = _OPTIONAL_PresidioPIIMasking(
@ -81,7 +81,7 @@ def initialize_presidio(litellm_params, guardrail):
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_success_callback)
litellm.logging_callback_manager.add_litellm_callback(_success_callback)
def initialize_hide_secrets(litellm_params, guardrail):
@ -93,7 +93,7 @@ def initialize_hide_secrets(litellm_params, guardrail):
guardrail_name=guardrail["guardrail_name"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_secret_detection_object)
litellm.logging_callback_manager.add_litellm_callback(_secret_detection_object)
def initialize_guardrails_ai(litellm_params, guardrail):
@ -111,4 +111,4 @@ def initialize_guardrails_ai(litellm_params, guardrail):
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_guardrails_ai_callback)
litellm.logging_callback_manager.add_litellm_callback(_guardrails_ai_callback)

View file

@ -157,7 +157,7 @@ def init_guardrails_v2(
event_hook=litellm_params["mode"],
default_on=litellm_params["default_on"],
)
litellm.callbacks.append(_guardrail_callback) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
else:
raise ValueError(f"Unsupported guardrail: {guardrail_type}")

View file

@ -736,7 +736,7 @@ user_api_key_cache = DualCache(
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.callbacks.append(model_max_budget_limiter)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)
@ -934,7 +934,7 @@ def cost_tracking():
if isinstance(litellm._async_success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore
litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore
litellm.logging_callback_manager.add_litellm_async_success_callback(_PROXY_track_cost_callback) # type: ignore
def error_tracking():
@ -943,7 +943,7 @@ def error_tracking():
if isinstance(litellm.failure_callback, list):
verbose_proxy_logger.debug("setting litellm failure callback to track cost")
if (_PROXY_failure_handler) not in litellm.failure_callback: # type: ignore
litellm.failure_callback.append(_PROXY_failure_handler) # type: ignore
litellm.logging_callback_manager.add_litellm_failure_callback(_PROXY_failure_handler) # type: ignore
def _set_spend_logs_payload(
@ -1890,12 +1890,14 @@ class ProxyConfig:
for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback:
litellm.success_callback.append(
litellm.logging_callback_manager.add_litellm_success_callback(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.success_callback.append(callback)
litellm.logging_callback_manager.add_litellm_success_callback(
callback
)
if "prometheus" in callback:
if not premium_user:
raise Exception(
@ -1919,12 +1921,14 @@ class ProxyConfig:
for callback in value:
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
if "." in callback:
litellm.failure_callback.append(
litellm.logging_callback_manager.add_litellm_failure_callback(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.failure_callback.append(callback)
litellm.logging_callback_manager.add_litellm_failure_callback(
callback
)
print( # noqa
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
) # noqa
@ -2215,7 +2219,7 @@ class ProxyConfig:
},
)
if _logger is not None:
litellm.callbacks.append(_logger)
litellm.logging_callback_manager.add_litellm_callback(_logger)
pass
def initialize_secret_manager(self, key_management_system: Optional[str]):
@ -2497,7 +2501,9 @@ class ProxyConfig:
success_callback, "success"
)
elif success_callback not in litellm.success_callback:
litellm.success_callback.append(success_callback)
litellm.logging_callback_manager.add_litellm_success_callback(
success_callback
)
# Add failure callbacks from DB to litellm
if failure_callbacks is not None and isinstance(failure_callbacks, list):
@ -2510,7 +2516,9 @@ class ProxyConfig:
failure_callback, "failure"
)
elif failure_callback not in litellm.failure_callback:
litellm.failure_callback.append(failure_callback)
litellm.logging_callback_manager.add_litellm_failure_callback(
failure_callback
)
def _add_environment_variables_from_db_config(self, config_data: dict) -> None:
"""

View file

@ -323,8 +323,8 @@ class ProxyLogging:
# NOTE: ENSURE we only add callbacks when alerting is on
# We should NOT add callbacks when alerting is off
if "daily_reports" in self.alert_types:
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
litellm.success_callback.append(
litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance) # type: ignore
litellm.logging_callback_manager.add_litellm_success_callback(
self.slack_alerting_instance.response_taking_too_long_callback
)
@ -332,10 +332,10 @@ class ProxyLogging:
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
litellm.callbacks.append(self.cache_control_check) # type: ignore
litellm.callbacks.append(self.service_logging_obj) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
@ -348,13 +348,13 @@ class ProxyLogging:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback) # type: ignore
if callback not in litellm.success_callback:
litellm.success_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
if callback not in litellm.service_callback:
litellm.service_callback.append(callback) # type: ignore

View file

@ -483,15 +483,21 @@ class Router:
self.access_groups = None
## USAGE TRACKING ##
if isinstance(litellm._async_success_callback, list):
litellm._async_success_callback.append(self.deployment_callback_on_success)
litellm.logging_callback_manager.add_litellm_async_success_callback(
self.deployment_callback_on_success
)
else:
litellm._async_success_callback.append(self.deployment_callback_on_success)
litellm.logging_callback_manager.add_litellm_async_success_callback(
self.deployment_callback_on_success
)
if isinstance(litellm.success_callback, list):
litellm.success_callback.append(self.sync_deployment_callback_on_success)
litellm.logging_callback_manager.add_litellm_success_callback(
self.sync_deployment_callback_on_success
)
else:
litellm.success_callback = [self.sync_deployment_callback_on_success]
if isinstance(litellm._async_failure_callback, list):
litellm._async_failure_callback.append(
litellm.logging_callback_manager.add_litellm_async_failure_callback(
self.async_deployment_callback_on_failure
)
else:
@ -500,7 +506,9 @@ class Router:
]
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
litellm.logging_callback_manager.add_litellm_failure_callback(
self.deployment_callback_on_failure
)
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
verbose_router_logger.debug(
@ -606,7 +614,7 @@ class Router:
model_list=self.model_list,
)
if _callback is not None:
litellm.callbacks.append(_callback)
litellm.logging_callback_manager.add_litellm_callback(_callback)
def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
@ -625,7 +633,7 @@ class Router:
else:
litellm.input_callback = [self.leastbusy_logger] # type: ignore
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.leastbusy_logger) # type: ignore
elif (
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
@ -636,7 +644,7 @@ class Router:
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger) # type: ignore
elif (
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
@ -647,7 +655,7 @@ class Router:
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger_v2) # type: ignore
elif (
routing_strategy == RoutingStrategy.LATENCY_BASED.value
or routing_strategy == RoutingStrategy.LATENCY_BASED
@ -658,7 +666,7 @@ class Router:
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.lowestlatency_logger) # type: ignore
elif (
routing_strategy == RoutingStrategy.COST_BASED.value
or routing_strategy == RoutingStrategy.COST_BASED
@ -669,7 +677,7 @@ class Router:
routing_args={},
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestcost_logger) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.lowestcost_logger) # type: ignore
else:
pass
@ -5835,8 +5843,8 @@ class Router:
self.slack_alerting_logger = _slack_alerting_logger
litellm.callbacks.append(_slack_alerting_logger) # type: ignore
litellm.success_callback.append(
litellm.logging_callback_manager.add_litellm_callback(_slack_alerting_logger) # type: ignore
litellm.logging_callback_manager.add_litellm_success_callback(
_slack_alerting_logger.response_taking_too_long_callback
)
verbose_router_logger.info(

View file

@ -64,7 +64,7 @@ class RouterBudgetLimiting(CustomLogger):
# Add self to litellm callbacks if it's a list
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore
async def async_filter_deployments(
self,

View file

@ -352,8 +352,12 @@ def _add_custom_logger_callback_to_specific_event(
and _custom_logger_class_exists_in_success_callbacks(callback_class)
is False
):
litellm.success_callback.append(callback_class)
litellm._async_success_callback.append(callback_class)
litellm.logging_callback_manager.add_litellm_success_callback(
callback_class
)
litellm.logging_callback_manager.add_litellm_async_success_callback(
callback_class
)
if callback in litellm.success_callback:
litellm.success_callback.remove(
callback
@ -367,8 +371,12 @@ def _add_custom_logger_callback_to_specific_event(
and _custom_logger_class_exists_in_failure_callbacks(callback_class)
is False
):
litellm.failure_callback.append(callback_class)
litellm._async_failure_callback.append(callback_class)
litellm.logging_callback_manager.add_litellm_failure_callback(
callback_class
)
litellm.logging_callback_manager.add_litellm_async_failure_callback(
callback_class
)
if callback in litellm.failure_callback:
litellm.failure_callback.remove(
callback
@ -447,13 +455,13 @@ def function_setup( # noqa: PLR0915
if callback not in litellm.input_callback:
litellm.input_callback.append(callback) # type: ignore
if callback not in litellm.success_callback:
litellm.success_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) # type: ignore
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
print_verbose(
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
)
@ -488,12 +496,16 @@ def function_setup( # noqa: PLR0915
removed_async_items = []
for index, callback in enumerate(litellm.success_callback): # type: ignore
if inspect.iscoroutinefunction(callback):
litellm._async_success_callback.append(callback)
litellm.logging_callback_manager.add_litellm_async_success_callback(
callback
)
removed_async_items.append(index)
elif callback == "dynamodb" or callback == "openmeter":
# dynamo is an async callback, it's used for the proxy and needs to be async
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
litellm._async_success_callback.append(callback)
litellm.logging_callback_manager.add_litellm_async_success_callback(
callback
)
removed_async_items.append(index)
elif (
callback in litellm._known_custom_logger_compatible_callbacks
@ -509,7 +521,9 @@ def function_setup( # noqa: PLR0915
removed_async_items = []
for index, callback in enumerate(litellm.failure_callback): # type: ignore
if inspect.iscoroutinefunction(callback):
litellm._async_failure_callback.append(callback)
litellm.logging_callback_manager.add_litellm_async_failure_callback(
callback
)
removed_async_items.append(index)
elif (
callback in litellm._known_custom_logger_compatible_callbacks

View file

@ -179,8 +179,9 @@ async def test_async_create_batch(provider):
2. Create Batch Request
3. Retrieve the specific batch
"""
litellm._turn_on_debug()
print("Testing async create batch")
litellm.logging_callback_manager._reset_all_callbacks()
custom_logger = TestCustomLogger()
litellm.callbacks = [custom_logger, "datadog"]

View file

@ -0,0 +1,113 @@
import ast
import os
# Files that are allowed to touch the callback lists directly (the manager
# itself, and the proxy's callback utilities).
# NOTE: fixed path typo - the module created by this change is
# litellm_core_utils/logging_callback_manager.py, not
# "litellm.logging_callback_manager.py".
ALLOWED_FILES = [
    # local files
    "../../litellm/litellm_core_utils/logging_callback_manager.py",
    "../../litellm/proxy/common_utils/callback_utils.py",
    # when running on ci/cd
    "./litellm/litellm_core_utils/logging_callback_manager.py",
    "./litellm/proxy/common_utils/callback_utils.py",
]

warning_msg = "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager"


def check_for_callback_modifications(file_path):
    """
    Checks if any direct modifications to specific litellm callback lists are made in the given file.
    Also prints the violating line of code.

    Returns a list of human-readable violation messages (empty when the file
    is clean, allow-listed, or cannot be parsed).
    """
    print("..checking file=", file_path)
    if file_path in ALLOWED_FILES:
        return []

    violations = []
    with open(file_path, "r") as file:
        try:
            lines = file.readlines()
            tree = ast.parse("".join(lines))
        except SyntaxError:
            print(f"Warning: Syntax error in file {file_path}")
            return violations

    protected_lists = [
        "callbacks",
        "success_callback",
        "failure_callback",
        "_async_success_callback",
        "_async_failure_callback",
    ]
    forbidden_operations = ["append", "extend", "insert"]

    for node in ast.walk(tree):
        # Check for attribute calls like litellm.callbacks.append()
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
            # Walk the attribute chain from leaf to root, then reverse it so
            # e.g. litellm.success_callback.append -> ["litellm", "success_callback", "append"]
            attr_chain = []
            current = node.func
            while isinstance(current, ast.Attribute):
                attr_chain.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                attr_chain.append(current.id)
            attr_chain = attr_chain[::-1]

            # Flag chains rooted at `litellm` that mutate a protected list.
            # (single combined condition - the original checked
            # forbidden_operations twice redundantly)
            if (
                len(attr_chain) >= 3
                and attr_chain[0] == "litellm"
                and attr_chain[1] in protected_lists
                and attr_chain[2] in forbidden_operations
            ):
                protected_list = attr_chain[1]
                operation = attr_chain[2]
                violating_line = lines[node.lineno - 1].strip()
                violations.append(
                    f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. "
                    f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. "
                    f"Please use LoggingCallbackManager instead. {warning_msg}"
                )
    return violations
def scan_directory_for_callback_modifications(base_dir):
    """
    Scans all Python files in the directory tree for unauthorized callback list modifications.

    Walks *base_dir* recursively and aggregates the violation messages
    produced by check_for_callback_modifications for every ``.py`` file.
    """
    all_violations = []
    for dirpath, _dirnames, filenames in os.walk(base_dir):
        for filename in (name for name in filenames if name.endswith(".py")):
            full_path = os.path.join(dirpath, filename)
            all_violations.extend(check_for_callback_modifications(full_path))
    return all_violations
def test_no_unauthorized_callback_modifications():
    """
    Test to ensure callback lists are not modified directly anywhere in the codebase.
    """
    base_dir = "./litellm"  # Adjust this path as needed
    # base_dir = "../../litellm"  # LOCAL TESTING
    violations = scan_directory_for_callback_modifications(base_dir)
    if not violations:
        return
    # Print every violation before failing so CI logs show the full picture
    print(f"\nFound {len(violations)} callback modification violations:")
    for violation in violations:
        print("\n" + violation)
    raise AssertionError(
        "Found unauthorized callback modifications. See above for details."
    )


if __name__ == "__main__":
    test_no_unauthorized_callback_modifications()

View file

@ -48,6 +48,7 @@ class BaseImageGenTest(ABC):
"""Test basic image generation"""
try:
custom_logger = TestCustomLogger()
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [custom_logger]
base_image_generation_call_args = self.get_base_image_generation_call_args()
litellm.set_verbose = True

View file

@ -0,0 +1,179 @@
import json
import os
import sys
import time
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
from litellm.integrations.langfuse.langfuse_prompt_management import (
LangfusePromptManagement,
)
from litellm.integrations.opentelemetry import OpenTelemetry
# Test fixtures
@pytest.fixture
def callback_manager():
    """Yield a LoggingCallbackManager with all litellm callback lists cleared."""
    manager = LoggingCallbackManager()
    # Start every test from a clean global-callback slate
    manager._reset_all_callbacks()
    return manager
@pytest.fixture
def mock_custom_logger():
    """Provide a minimal CustomLogger subclass instance for dedup tests."""

    class _StubLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            pass

    return _StubLogger()
# Test cases
def test_add_string_callback():
    """
    Test adding a string callback to litellm.callbacks - only 1 instance of the string callback should be added
    """
    manager = LoggingCallbackManager()
    # Reset global callback state so results from previously-run tests don't leak in
    manager._reset_all_callbacks()
    test_callback = "test_callback"

    # Add string callback
    manager.add_litellm_callback(test_callback)
    assert test_callback in litellm.callbacks

    # Test duplicate prevention
    manager.add_litellm_callback(test_callback)
    assert litellm.callbacks.count(test_callback) == 1
def test_duplicate_langfuse_logger_test():
    """
    Adding 10 equivalent LangfusePromptManagement instances must leave exactly
    one success callback registered.
    """
    manager = LoggingCallbackManager()
    # Reset global callback state so this test is independent of run order
    manager._reset_all_callbacks()
    for _ in range(10):
        langfuse_logger = LangfusePromptManagement()
        manager.add_litellm_success_callback(langfuse_logger)
    print("litellm.success_callback: ", litellm.success_callback)
    assert len(litellm.success_callback) == 1
def test_duplicate_multiple_loggers_test():
    """
    Adding two different logger types 10 times each should leave exactly one
    instance of each type in litellm.success_callback.
    """
    manager = LoggingCallbackManager()
    # The exact-count assertions below read the GLOBAL litellm.success_callback
    # list; reset it first so earlier tests in the session cannot skew them.
    manager._reset_all_callbacks()
    for _ in range(10):
        langfuse_logger = LangfusePromptManagement()
        otel_logger = OpenTelemetry()
        manager.add_litellm_success_callback(langfuse_logger)
        manager.add_litellm_success_callback(otel_logger)
    print("litellm.success_callback: ", litellm.success_callback)
    assert len(litellm.success_callback) == 2

    # Check exactly one instance of each logger type
    langfuse_count = sum(
        1
        for callback in litellm.success_callback
        if isinstance(callback, LangfusePromptManagement)
    )
    otel_count = sum(
        1
        for callback in litellm.success_callback
        if isinstance(callback, OpenTelemetry)
    )

    assert (
        langfuse_count == 1
    ), "Should have exactly one LangfusePromptManagement instance"
    assert otel_count == 1, "Should have exactly one OpenTelemetry instance"
def test_add_function_callback():
    """A plain function can be registered once; re-adding it is deduplicated."""
    mgr = LoggingCallbackManager()

    def test_func(kwargs):
        pass

    # Register the function callback.
    mgr.add_litellm_callback(test_func)
    assert test_func in litellm.callbacks

    # Re-registering the identical function object must not create a duplicate.
    mgr.add_litellm_callback(test_func)
    assert litellm.callbacks.count(test_func) == 1
def test_add_custom_logger(mock_custom_logger):
    """A CustomLogger instance is registered into litellm.callbacks."""
    mgr = LoggingCallbackManager()
    mgr.add_litellm_callback(mock_custom_logger)
    assert mock_custom_logger in litellm.callbacks
def test_add_multiple_callback_types(mock_custom_logger):
    """
    A string, a function, and a CustomLogger can all be registered, after
    which litellm.callbacks contains exactly those three entries.
    """
    manager = LoggingCallbackManager()
    # `len(litellm.callbacks) == 3` below requires a clean GLOBAL callback
    # list; earlier tests in this module append to it, so reset first.
    manager._reset_all_callbacks()

    def test_func(kwargs):
        pass

    string_callback = "test_callback"

    # Add different types of callbacks
    manager.add_litellm_callback(string_callback)
    manager.add_litellm_callback(test_func)
    manager.add_litellm_callback(mock_custom_logger)

    assert string_callback in litellm.callbacks
    assert test_func in litellm.callbacks
    assert mock_custom_logger in litellm.callbacks
    assert len(litellm.callbacks) == 3
def test_success_failure_callbacks():
    """Success and failure callbacks land in their respective global lists."""
    mgr = LoggingCallbackManager()
    on_success = "success_callback"
    on_failure = "failure_callback"

    mgr.add_litellm_success_callback(on_success)
    mgr.add_litellm_failure_callback(on_failure)

    assert on_success in litellm.success_callback
    assert on_failure in litellm.failure_callback
def test_async_callbacks():
    """Async success/failure callbacks land in their respective global lists."""
    mgr = LoggingCallbackManager()
    on_async_success = "async_success"
    on_async_failure = "async_failure"

    mgr.add_litellm_async_success_callback(on_async_success)
    mgr.add_litellm_async_failure_callback(on_async_failure)

    assert on_async_success in litellm._async_success_callback
    assert on_async_failure in litellm._async_failure_callback
def test_reset_callbacks(callback_manager):
    """_reset_all_callbacks empties every global callback list."""
    # Populate one entry in each of the five callback lists.
    callback_manager.add_litellm_callback("test")
    callback_manager.add_litellm_success_callback("success")
    callback_manager.add_litellm_failure_callback("failure")
    callback_manager.add_litellm_async_success_callback("async_success")
    callback_manager.add_litellm_async_failure_callback("async_failure")

    # Wipe everything.
    callback_manager._reset_all_callbacks()

    # Every global callback list must now be empty.
    for callback_list in (
        litellm.callbacks,
        litellm.success_callback,
        litellm.failure_callback,
        litellm._async_success_callback,
        litellm._async_failure_callback,
    ):
        assert len(callback_list) == 0

View file

@ -846,6 +846,7 @@ async def test_async_embedding_openai():
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(
@ -882,6 +883,7 @@ def test_amazing_sync_embedding():
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
try:
response = litellm.embedding(
@ -916,6 +918,7 @@ async def test_async_embedding_azure():
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(
@ -956,6 +959,7 @@ async def test_async_embedding_bedrock():
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(
@ -1123,6 +1127,7 @@ def test_image_generation_openai():
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
try:
response = litellm.image_generation(

View file

@ -415,6 +415,8 @@ async def test_async_chat_azure():
len(customHandler_completion_azure_router.states) == 3
) # pre, post, success
# streaming
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_streaming_azure_router]
router2 = Router(model_list=model_list, num_retries=0) # type: ignore
response = await router2.acompletion(
@ -445,6 +447,8 @@ async def test_async_chat_azure():
"rpm": 1800,
},
]
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
try:
@ -507,6 +511,7 @@ async def test_async_embedding_azure():
"rpm": 1800,
},
]
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
try:

View file

@ -261,6 +261,7 @@ def test_azure_completion_stream():
@pytest.mark.asyncio
async def test_async_custom_handler_completion():
try:
        litellm._turn_on_debug()
customHandler_success = MyCustomHandler()
customHandler_failure = MyCustomHandler()
# success
@ -284,6 +285,7 @@ async def test_async_custom_handler_completion():
== "gpt-3.5-turbo"
)
# failure
litellm.logging_callback_manager._reset_all_callbacks()
litellm.callbacks = [customHandler_failure]
messages = [
{"role": "system", "content": "You are a helpful assistant."},

View file

@ -13,6 +13,7 @@ sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
from pydantic import BaseModel
from litellm import utils, Router
@ -124,6 +125,7 @@ def test_rate_limit(
ExpectNoException: Signfies that no other error has happened. A NOP
"""
# Can send more messages then we're going to; so don't expect a rate limit error
litellm.logging_callback_manager._reset_all_callbacks()
args = locals()
print(f"args: {args}")
expected_exception = (