mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
(Refactor / QA) - Use LoggingCallbackManager
to append callbacks and ensure no duplicate callbacks are added (#8112)
* LoggingCallbackManager * add logging_callback_manager * use logging_callback_manager * add add_litellm_failure_callback * use add_litellm_callback * use add_litellm_async_success_callback * add_litellm_async_failure_callback * linting fix * fix logging callback manager * test_duplicate_multiple_loggers_test * use _reset_all_callbacks * fix testing with dup callbacks * test_basic_image_generation * reset callbacks for tests * fix check for _add_custom_logger_to_list * fix test_amazing_sync_embedding * fix _get_custom_logger_key * fix batches testing * fix _reset_all_callbacks * fix _check_callback_list_size * add callback_manager_test * fix test gemini-2.0-flash-thinking-exp-01-21
This commit is contained in:
parent
11c8d07ed3
commit
fa1c42378f
19 changed files with 607 additions and 59 deletions
|
@ -994,6 +994,7 @@ jobs:
|
||||||
- run: ruff check ./litellm
|
- run: ruff check ./litellm
|
||||||
# - run: python ./tests/documentation_tests/test_general_setting_keys.py
|
# - run: python ./tests/documentation_tests/test_general_setting_keys.py
|
||||||
- run: python ./tests/code_coverage_tests/router_code_coverage.py
|
- run: python ./tests/code_coverage_tests/router_code_coverage.py
|
||||||
|
- run: python ./tests/code_coverage_tests/callback_manager_test.py
|
||||||
- run: python ./tests/code_coverage_tests/recursive_detector.py
|
- run: python ./tests/code_coverage_tests/recursive_detector.py
|
||||||
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
|
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
|
||||||
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
|
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
|
||||||
|
|
|
@ -38,6 +38,7 @@ from litellm.proxy._types import (
|
||||||
)
|
)
|
||||||
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
|
||||||
import httpx
|
import httpx
|
||||||
import dotenv
|
import dotenv
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -50,6 +51,7 @@ if set_verbose == True:
|
||||||
_turn_on_debug()
|
_turn_on_debug()
|
||||||
###############################################
|
###############################################
|
||||||
### Callbacks /Logging / Success / Failure Handlers #####
|
### Callbacks /Logging / Success / Failure Handlers #####
|
||||||
|
logging_callback_manager = LoggingCallbackManager()
|
||||||
input_callback: List[Union[str, Callable, CustomLogger]] = []
|
input_callback: List[Union[str, Callable, CustomLogger]] = []
|
||||||
success_callback: List[Union[str, Callable, CustomLogger]] = []
|
success_callback: List[Union[str, Callable, CustomLogger]] = []
|
||||||
failure_callback: List[Union[str, Callable, CustomLogger]] = []
|
failure_callback: List[Union[str, Callable, CustomLogger]] = []
|
||||||
|
|
|
@ -207,9 +207,9 @@ class Cache:
|
||||||
if "cache" not in litellm.input_callback:
|
if "cache" not in litellm.input_callback:
|
||||||
litellm.input_callback.append("cache")
|
litellm.input_callback.append("cache")
|
||||||
if "cache" not in litellm.success_callback:
|
if "cache" not in litellm.success_callback:
|
||||||
litellm.success_callback.append("cache")
|
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||||
if "cache" not in litellm._async_success_callback:
|
if "cache" not in litellm._async_success_callback:
|
||||||
litellm._async_success_callback.append("cache")
|
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||||
self.type = type
|
self.type = type
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
|
@ -774,9 +774,9 @@ def enable_cache(
|
||||||
if "cache" not in litellm.input_callback:
|
if "cache" not in litellm.input_callback:
|
||||||
litellm.input_callback.append("cache")
|
litellm.input_callback.append("cache")
|
||||||
if "cache" not in litellm.success_callback:
|
if "cache" not in litellm.success_callback:
|
||||||
litellm.success_callback.append("cache")
|
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||||
if "cache" not in litellm._async_success_callback:
|
if "cache" not in litellm._async_success_callback:
|
||||||
litellm._async_success_callback.append("cache")
|
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||||
|
|
||||||
if litellm.cache is None:
|
if litellm.cache is None:
|
||||||
litellm.cache = Cache(
|
litellm.cache = Cache(
|
||||||
|
|
207
litellm/litellm_core_utils/logging_callback_manager.py
Normal file
207
litellm/litellm_core_utils/logging_callback_manager.py
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingCallbackManager:
|
||||||
|
"""
|
||||||
|
A centralized class that allows easy add / remove callbacks for litellm.
|
||||||
|
|
||||||
|
Goals of this class:
|
||||||
|
- Prevent adding duplicate callbacks / success_callback / failure_callback
|
||||||
|
- Keep a reasonable MAX_CALLBACKS limit (this ensures callbacks don't exponentially grow and consume CPU Resources)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# healthy maximum number of callbacks - unlikely someone needs more than 20
|
||||||
|
MAX_CALLBACKS = 30
|
||||||
|
|
||||||
|
def add_litellm_input_callback(self, callback: Union[CustomLogger, str]):
|
||||||
|
"""
|
||||||
|
Add a input callback to litellm.input_callback
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm.input_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_litellm_service_callback(
|
||||||
|
self, callback: Union[CustomLogger, str, Callable]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a service callback to litellm.service_callback
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm.service_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_litellm_callback(self, callback: Union[CustomLogger, str, Callable]):
|
||||||
|
"""
|
||||||
|
Add a callback to litellm.callbacks
|
||||||
|
|
||||||
|
Ensures no duplicates are added.
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm.callbacks # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_litellm_success_callback(
|
||||||
|
self, callback: Union[CustomLogger, str, Callable]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a success callback to `litellm.success_callback`
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm.success_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_litellm_failure_callback(
|
||||||
|
self, callback: Union[CustomLogger, str, Callable]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a failure callback to `litellm.failure_callback`
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm.failure_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_litellm_async_success_callback(
|
||||||
|
self, callback: Union[CustomLogger, Callable, str]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a success callback to litellm._async_success_callback
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm._async_success_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_litellm_async_failure_callback(
|
||||||
|
self, callback: Union[CustomLogger, Callable, str]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a failure callback to litellm._async_failure_callback
|
||||||
|
"""
|
||||||
|
self._safe_add_callback_to_list(
|
||||||
|
callback=callback, parent_list=litellm._async_failure_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_string_callback_to_list(
|
||||||
|
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a string callback to a list, if the callback is already in the list, do not add it again.
|
||||||
|
"""
|
||||||
|
if callback not in parent_list:
|
||||||
|
parent_list.append(callback)
|
||||||
|
else:
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Callback {callback} already exists in {parent_list}, not adding again.."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_callback_list_size(
|
||||||
|
self, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if adding another callback would exceed MAX_CALLBACKS
|
||||||
|
Returns True if safe to add, False if would exceed limit
|
||||||
|
"""
|
||||||
|
if len(parent_list) >= self.MAX_CALLBACKS:
|
||||||
|
verbose_logger.warning(
|
||||||
|
f"Cannot add callback - would exceed MAX_CALLBACKS limit of {self.MAX_CALLBACKS}. Current callbacks: {len(parent_list)}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _safe_add_callback_to_list(
|
||||||
|
self,
|
||||||
|
callback: Union[CustomLogger, Callable, str],
|
||||||
|
parent_list: List[Union[CustomLogger, Callable, str]],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Safe add a callback to a list, if the callback is already in the list, do not add it again.
|
||||||
|
|
||||||
|
Ensures no duplicates are added for `str`, `Callable`, and `CustomLogger` callbacks.
|
||||||
|
"""
|
||||||
|
# Check max callbacks limit first
|
||||||
|
if not self._check_callback_list_size(parent_list):
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(callback, str):
|
||||||
|
self._add_string_callback_to_list(
|
||||||
|
callback=callback, parent_list=parent_list
|
||||||
|
)
|
||||||
|
elif isinstance(callback, CustomLogger):
|
||||||
|
self._add_custom_logger_to_list(
|
||||||
|
custom_logger=callback,
|
||||||
|
parent_list=parent_list,
|
||||||
|
)
|
||||||
|
elif callable(callback):
|
||||||
|
self._add_callback_function_to_list(
|
||||||
|
callback=callback, parent_list=parent_list
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_callback_function_to_list(
|
||||||
|
self, callback: Callable, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a callback function to a list, if the callback is already in the list, do not add it again.
|
||||||
|
"""
|
||||||
|
# Check if the function already exists in the list by comparing function objects
|
||||||
|
if callback not in parent_list:
|
||||||
|
parent_list.append(callback)
|
||||||
|
else:
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Callback function {callback.__name__} already exists in {parent_list}, not adding again.."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_custom_logger_to_list(
|
||||||
|
self,
|
||||||
|
custom_logger: CustomLogger,
|
||||||
|
parent_list: List[Union[CustomLogger, Callable, str]],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a custom logger to a list, if another instance of the same custom logger exists in the list, do not add it again.
|
||||||
|
"""
|
||||||
|
# Check if an instance of the same class already exists in the list
|
||||||
|
custom_logger_key = self._get_custom_logger_key(custom_logger)
|
||||||
|
custom_logger_type_name = type(custom_logger).__name__
|
||||||
|
for existing_logger in parent_list:
|
||||||
|
if (
|
||||||
|
isinstance(existing_logger, CustomLogger)
|
||||||
|
and self._get_custom_logger_key(existing_logger) == custom_logger_key
|
||||||
|
):
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Custom logger of type {custom_logger_type_name}, key: {custom_logger_key} already exists in {parent_list}, not adding again.."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
parent_list.append(custom_logger)
|
||||||
|
|
||||||
|
def _get_custom_logger_key(self, custom_logger: CustomLogger):
|
||||||
|
"""
|
||||||
|
Get a unique key for a custom logger that considers only fundamental instance variables
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A unique key combining the class name and fundamental instance variables (str, bool, int)
|
||||||
|
"""
|
||||||
|
key_parts = [type(custom_logger).__name__]
|
||||||
|
|
||||||
|
# Add only fundamental type instance variables to the key
|
||||||
|
for attr_name, attr_value in vars(custom_logger).items():
|
||||||
|
if not attr_name.startswith("_"): # Skip private attributes
|
||||||
|
if isinstance(attr_value, (str, bool, int)):
|
||||||
|
key_parts.append(f"{attr_name}={attr_value}")
|
||||||
|
|
||||||
|
return "-".join(key_parts)
|
||||||
|
|
||||||
|
def _reset_all_callbacks(self):
|
||||||
|
"""
|
||||||
|
Reset all callbacks to an empty list
|
||||||
|
|
||||||
|
Note: this is an internal function and should be used sparingly.
|
||||||
|
"""
|
||||||
|
litellm.input_callback = []
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm.failure_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
litellm._async_failure_callback = []
|
||||||
|
litellm.callbacks = []
|
|
@ -13,7 +13,7 @@ def initialize_aporia(litellm_params, guardrail):
|
||||||
event_hook=litellm_params["mode"],
|
event_hook=litellm_params["mode"],
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_aporia_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_aporia_callback)
|
||||||
|
|
||||||
|
|
||||||
def initialize_bedrock(litellm_params, guardrail):
|
def initialize_bedrock(litellm_params, guardrail):
|
||||||
|
@ -28,7 +28,7 @@ def initialize_bedrock(litellm_params, guardrail):
|
||||||
guardrailVersion=litellm_params["guardrailVersion"],
|
guardrailVersion=litellm_params["guardrailVersion"],
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_bedrock_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback)
|
||||||
|
|
||||||
|
|
||||||
def initialize_lakera(litellm_params, guardrail):
|
def initialize_lakera(litellm_params, guardrail):
|
||||||
|
@ -42,7 +42,7 @@ def initialize_lakera(litellm_params, guardrail):
|
||||||
category_thresholds=litellm_params.get("category_thresholds"),
|
category_thresholds=litellm_params.get("category_thresholds"),
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_lakera_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_lakera_callback)
|
||||||
|
|
||||||
|
|
||||||
def initialize_aim(litellm_params, guardrail):
|
def initialize_aim(litellm_params, guardrail):
|
||||||
|
@ -55,7 +55,7 @@ def initialize_aim(litellm_params, guardrail):
|
||||||
event_hook=litellm_params["mode"],
|
event_hook=litellm_params["mode"],
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_aim_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_aim_callback)
|
||||||
|
|
||||||
|
|
||||||
def initialize_presidio(litellm_params, guardrail):
|
def initialize_presidio(litellm_params, guardrail):
|
||||||
|
@ -71,7 +71,7 @@ def initialize_presidio(litellm_params, guardrail):
|
||||||
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
|
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_presidio_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_presidio_callback)
|
||||||
|
|
||||||
if litellm_params["output_parse_pii"]:
|
if litellm_params["output_parse_pii"]:
|
||||||
_success_callback = _OPTIONAL_PresidioPIIMasking(
|
_success_callback = _OPTIONAL_PresidioPIIMasking(
|
||||||
|
@ -81,7 +81,7 @@ def initialize_presidio(litellm_params, guardrail):
|
||||||
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
|
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_success_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_success_callback)
|
||||||
|
|
||||||
|
|
||||||
def initialize_hide_secrets(litellm_params, guardrail):
|
def initialize_hide_secrets(litellm_params, guardrail):
|
||||||
|
@ -93,7 +93,7 @@ def initialize_hide_secrets(litellm_params, guardrail):
|
||||||
guardrail_name=guardrail["guardrail_name"],
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_secret_detection_object)
|
litellm.logging_callback_manager.add_litellm_callback(_secret_detection_object)
|
||||||
|
|
||||||
|
|
||||||
def initialize_guardrails_ai(litellm_params, guardrail):
|
def initialize_guardrails_ai(litellm_params, guardrail):
|
||||||
|
@ -111,4 +111,4 @@ def initialize_guardrails_ai(litellm_params, guardrail):
|
||||||
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
|
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_guardrails_ai_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_guardrails_ai_callback)
|
||||||
|
|
|
@ -157,7 +157,7 @@ def init_guardrails_v2(
|
||||||
event_hook=litellm_params["mode"],
|
event_hook=litellm_params["mode"],
|
||||||
default_on=litellm_params["default_on"],
|
default_on=litellm_params["default_on"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_guardrail_callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
||||||
|
|
||||||
|
|
|
@ -736,7 +736,7 @@ user_api_key_cache = DualCache(
|
||||||
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
|
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
|
||||||
dual_cache=user_api_key_cache
|
dual_cache=user_api_key_cache
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(model_max_budget_limiter)
|
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
|
||||||
redis_usage_cache: Optional[RedisCache] = (
|
redis_usage_cache: Optional[RedisCache] = (
|
||||||
None # redis cache used for tracking spend, tpm/rpm limits
|
None # redis cache used for tracking spend, tpm/rpm limits
|
||||||
)
|
)
|
||||||
|
@ -934,7 +934,7 @@ def cost_tracking():
|
||||||
if isinstance(litellm._async_success_callback, list):
|
if isinstance(litellm._async_success_callback, list):
|
||||||
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
||||||
if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore
|
if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore
|
||||||
litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_async_success_callback(_PROXY_track_cost_callback) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def error_tracking():
|
def error_tracking():
|
||||||
|
@ -943,7 +943,7 @@ def error_tracking():
|
||||||
if isinstance(litellm.failure_callback, list):
|
if isinstance(litellm.failure_callback, list):
|
||||||
verbose_proxy_logger.debug("setting litellm failure callback to track cost")
|
verbose_proxy_logger.debug("setting litellm failure callback to track cost")
|
||||||
if (_PROXY_failure_handler) not in litellm.failure_callback: # type: ignore
|
if (_PROXY_failure_handler) not in litellm.failure_callback: # type: ignore
|
||||||
litellm.failure_callback.append(_PROXY_failure_handler) # type: ignore
|
litellm.logging_callback_manager.add_litellm_failure_callback(_PROXY_failure_handler) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def _set_spend_logs_payload(
|
def _set_spend_logs_payload(
|
||||||
|
@ -1890,12 +1890,14 @@ class ProxyConfig:
|
||||||
for callback in value:
|
for callback in value:
|
||||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||||
if "." in callback:
|
if "." in callback:
|
||||||
litellm.success_callback.append(
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
get_instance_fn(value=callback)
|
get_instance_fn(value=callback)
|
||||||
)
|
)
|
||||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||||
else:
|
else:
|
||||||
litellm.success_callback.append(callback)
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
|
callback
|
||||||
|
)
|
||||||
if "prometheus" in callback:
|
if "prometheus" in callback:
|
||||||
if not premium_user:
|
if not premium_user:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -1919,12 +1921,14 @@ class ProxyConfig:
|
||||||
for callback in value:
|
for callback in value:
|
||||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||||
if "." in callback:
|
if "." in callback:
|
||||||
litellm.failure_callback.append(
|
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||||
get_instance_fn(value=callback)
|
get_instance_fn(value=callback)
|
||||||
)
|
)
|
||||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||||
else:
|
else:
|
||||||
litellm.failure_callback.append(callback)
|
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||||
|
callback
|
||||||
|
)
|
||||||
print( # noqa
|
print( # noqa
|
||||||
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
|
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
|
||||||
) # noqa
|
) # noqa
|
||||||
|
@ -2215,7 +2219,7 @@ class ProxyConfig:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if _logger is not None:
|
if _logger is not None:
|
||||||
litellm.callbacks.append(_logger)
|
litellm.logging_callback_manager.add_litellm_callback(_logger)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def initialize_secret_manager(self, key_management_system: Optional[str]):
|
def initialize_secret_manager(self, key_management_system: Optional[str]):
|
||||||
|
@ -2497,7 +2501,9 @@ class ProxyConfig:
|
||||||
success_callback, "success"
|
success_callback, "success"
|
||||||
)
|
)
|
||||||
elif success_callback not in litellm.success_callback:
|
elif success_callback not in litellm.success_callback:
|
||||||
litellm.success_callback.append(success_callback)
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
|
success_callback
|
||||||
|
)
|
||||||
|
|
||||||
# Add failure callbacks from DB to litellm
|
# Add failure callbacks from DB to litellm
|
||||||
if failure_callbacks is not None and isinstance(failure_callbacks, list):
|
if failure_callbacks is not None and isinstance(failure_callbacks, list):
|
||||||
|
@ -2510,7 +2516,9 @@ class ProxyConfig:
|
||||||
failure_callback, "failure"
|
failure_callback, "failure"
|
||||||
)
|
)
|
||||||
elif failure_callback not in litellm.failure_callback:
|
elif failure_callback not in litellm.failure_callback:
|
||||||
litellm.failure_callback.append(failure_callback)
|
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||||
|
failure_callback
|
||||||
|
)
|
||||||
|
|
||||||
def _add_environment_variables_from_db_config(self, config_data: dict) -> None:
|
def _add_environment_variables_from_db_config(self, config_data: dict) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -323,8 +323,8 @@ class ProxyLogging:
|
||||||
# NOTE: ENSURE we only add callbacks when alerting is on
|
# NOTE: ENSURE we only add callbacks when alerting is on
|
||||||
# We should NOT add callbacks when alerting is off
|
# We should NOT add callbacks when alerting is off
|
||||||
if "daily_reports" in self.alert_types:
|
if "daily_reports" in self.alert_types:
|
||||||
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance) # type: ignore
|
||||||
litellm.success_callback.append(
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
self.slack_alerting_instance.response_taking_too_long_callback
|
self.slack_alerting_instance.response_taking_too_long_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -332,10 +332,10 @@ class ProxyLogging:
|
||||||
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
|
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
|
||||||
|
|
||||||
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
||||||
litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore
|
||||||
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
|
||||||
litellm.callbacks.append(self.cache_control_check) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
|
||||||
litellm.callbacks.append(self.service_logging_obj) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
|
||||||
for callback in litellm.callbacks:
|
for callback in litellm.callbacks:
|
||||||
if isinstance(callback, str):
|
if isinstance(callback, str):
|
||||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
|
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
|
||||||
|
@ -348,13 +348,13 @@ class ProxyLogging:
|
||||||
if callback not in litellm.input_callback:
|
if callback not in litellm.input_callback:
|
||||||
litellm.input_callback.append(callback) # type: ignore
|
litellm.input_callback.append(callback) # type: ignore
|
||||||
if callback not in litellm.success_callback:
|
if callback not in litellm.success_callback:
|
||||||
litellm.success_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
|
||||||
if callback not in litellm.failure_callback:
|
if callback not in litellm.failure_callback:
|
||||||
litellm.failure_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
|
||||||
if callback not in litellm._async_success_callback:
|
if callback not in litellm._async_success_callback:
|
||||||
litellm._async_success_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
|
||||||
if callback not in litellm._async_failure_callback:
|
if callback not in litellm._async_failure_callback:
|
||||||
litellm._async_failure_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
|
||||||
if callback not in litellm.service_callback:
|
if callback not in litellm.service_callback:
|
||||||
litellm.service_callback.append(callback) # type: ignore
|
litellm.service_callback.append(callback) # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -483,15 +483,21 @@ class Router:
|
||||||
self.access_groups = None
|
self.access_groups = None
|
||||||
## USAGE TRACKING ##
|
## USAGE TRACKING ##
|
||||||
if isinstance(litellm._async_success_callback, list):
|
if isinstance(litellm._async_success_callback, list):
|
||||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||||
|
self.deployment_callback_on_success
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||||
|
self.deployment_callback_on_success
|
||||||
|
)
|
||||||
if isinstance(litellm.success_callback, list):
|
if isinstance(litellm.success_callback, list):
|
||||||
litellm.success_callback.append(self.sync_deployment_callback_on_success)
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
|
self.sync_deployment_callback_on_success
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
litellm.success_callback = [self.sync_deployment_callback_on_success]
|
litellm.success_callback = [self.sync_deployment_callback_on_success]
|
||||||
if isinstance(litellm._async_failure_callback, list):
|
if isinstance(litellm._async_failure_callback, list):
|
||||||
litellm._async_failure_callback.append(
|
litellm.logging_callback_manager.add_litellm_async_failure_callback(
|
||||||
self.async_deployment_callback_on_failure
|
self.async_deployment_callback_on_failure
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -500,7 +506,9 @@ class Router:
|
||||||
]
|
]
|
||||||
## COOLDOWNS ##
|
## COOLDOWNS ##
|
||||||
if isinstance(litellm.failure_callback, list):
|
if isinstance(litellm.failure_callback, list):
|
||||||
litellm.failure_callback.append(self.deployment_callback_on_failure)
|
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||||
|
self.deployment_callback_on_failure
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
litellm.failure_callback = [self.deployment_callback_on_failure]
|
litellm.failure_callback = [self.deployment_callback_on_failure]
|
||||||
verbose_router_logger.debug(
|
verbose_router_logger.debug(
|
||||||
|
@ -606,7 +614,7 @@ class Router:
|
||||||
model_list=self.model_list,
|
model_list=self.model_list,
|
||||||
)
|
)
|
||||||
if _callback is not None:
|
if _callback is not None:
|
||||||
litellm.callbacks.append(_callback)
|
litellm.logging_callback_manager.add_litellm_callback(_callback)
|
||||||
|
|
||||||
def routing_strategy_init(
|
def routing_strategy_init(
|
||||||
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
|
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
|
||||||
|
@ -625,7 +633,7 @@ class Router:
|
||||||
else:
|
else:
|
||||||
litellm.input_callback = [self.leastbusy_logger] # type: ignore
|
litellm.input_callback = [self.leastbusy_logger] # type: ignore
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.leastbusy_logger) # type: ignore
|
||||||
elif (
|
elif (
|
||||||
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
|
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
|
||||||
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
|
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
|
||||||
|
@ -636,7 +644,7 @@ class Router:
|
||||||
routing_args=routing_strategy_args,
|
routing_args=routing_strategy_args,
|
||||||
)
|
)
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger) # type: ignore
|
||||||
elif (
|
elif (
|
||||||
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
|
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
|
||||||
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
|
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
|
||||||
|
@ -647,7 +655,7 @@ class Router:
|
||||||
routing_args=routing_strategy_args,
|
routing_args=routing_strategy_args,
|
||||||
)
|
)
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger_v2) # type: ignore
|
||||||
elif (
|
elif (
|
||||||
routing_strategy == RoutingStrategy.LATENCY_BASED.value
|
routing_strategy == RoutingStrategy.LATENCY_BASED.value
|
||||||
or routing_strategy == RoutingStrategy.LATENCY_BASED
|
or routing_strategy == RoutingStrategy.LATENCY_BASED
|
||||||
|
@ -658,7 +666,7 @@ class Router:
|
||||||
routing_args=routing_strategy_args,
|
routing_args=routing_strategy_args,
|
||||||
)
|
)
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.lowestlatency_logger) # type: ignore
|
||||||
elif (
|
elif (
|
||||||
routing_strategy == RoutingStrategy.COST_BASED.value
|
routing_strategy == RoutingStrategy.COST_BASED.value
|
||||||
or routing_strategy == RoutingStrategy.COST_BASED
|
or routing_strategy == RoutingStrategy.COST_BASED
|
||||||
|
@ -669,7 +677,7 @@ class Router:
|
||||||
routing_args={},
|
routing_args={},
|
||||||
)
|
)
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self.lowestcost_logger) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self.lowestcost_logger) # type: ignore
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -5835,8 +5843,8 @@ class Router:
|
||||||
|
|
||||||
self.slack_alerting_logger = _slack_alerting_logger
|
self.slack_alerting_logger = _slack_alerting_logger
|
||||||
|
|
||||||
litellm.callbacks.append(_slack_alerting_logger) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(_slack_alerting_logger) # type: ignore
|
||||||
litellm.success_callback.append(
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
_slack_alerting_logger.response_taking_too_long_callback
|
_slack_alerting_logger.response_taking_too_long_callback
|
||||||
)
|
)
|
||||||
verbose_router_logger.info(
|
verbose_router_logger.info(
|
||||||
|
|
|
@ -64,7 +64,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
|
|
||||||
# Add self to litellm callbacks if it's a list
|
# Add self to litellm callbacks if it's a list
|
||||||
if isinstance(litellm.callbacks, list):
|
if isinstance(litellm.callbacks, list):
|
||||||
litellm.callbacks.append(self) # type: ignore
|
litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore
|
||||||
|
|
||||||
async def async_filter_deployments(
|
async def async_filter_deployments(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -352,8 +352,12 @@ def _add_custom_logger_callback_to_specific_event(
|
||||||
and _custom_logger_class_exists_in_success_callbacks(callback_class)
|
and _custom_logger_class_exists_in_success_callbacks(callback_class)
|
||||||
is False
|
is False
|
||||||
):
|
):
|
||||||
litellm.success_callback.append(callback_class)
|
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||||
litellm._async_success_callback.append(callback_class)
|
callback_class
|
||||||
|
)
|
||||||
|
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||||
|
callback_class
|
||||||
|
)
|
||||||
if callback in litellm.success_callback:
|
if callback in litellm.success_callback:
|
||||||
litellm.success_callback.remove(
|
litellm.success_callback.remove(
|
||||||
callback
|
callback
|
||||||
|
@ -367,8 +371,12 @@ def _add_custom_logger_callback_to_specific_event(
|
||||||
and _custom_logger_class_exists_in_failure_callbacks(callback_class)
|
and _custom_logger_class_exists_in_failure_callbacks(callback_class)
|
||||||
is False
|
is False
|
||||||
):
|
):
|
||||||
litellm.failure_callback.append(callback_class)
|
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||||
litellm._async_failure_callback.append(callback_class)
|
callback_class
|
||||||
|
)
|
||||||
|
litellm.logging_callback_manager.add_litellm_async_failure_callback(
|
||||||
|
callback_class
|
||||||
|
)
|
||||||
if callback in litellm.failure_callback:
|
if callback in litellm.failure_callback:
|
||||||
litellm.failure_callback.remove(
|
litellm.failure_callback.remove(
|
||||||
callback
|
callback
|
||||||
|
@ -447,13 +455,13 @@ def function_setup( # noqa: PLR0915
|
||||||
if callback not in litellm.input_callback:
|
if callback not in litellm.input_callback:
|
||||||
litellm.input_callback.append(callback) # type: ignore
|
litellm.input_callback.append(callback) # type: ignore
|
||||||
if callback not in litellm.success_callback:
|
if callback not in litellm.success_callback:
|
||||||
litellm.success_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
|
||||||
if callback not in litellm.failure_callback:
|
if callback not in litellm.failure_callback:
|
||||||
litellm.failure_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
|
||||||
if callback not in litellm._async_success_callback:
|
if callback not in litellm._async_success_callback:
|
||||||
litellm._async_success_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
|
||||||
if callback not in litellm._async_failure_callback:
|
if callback not in litellm._async_failure_callback:
|
||||||
litellm._async_failure_callback.append(callback) # type: ignore
|
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
|
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
|
||||||
)
|
)
|
||||||
|
@ -488,12 +496,16 @@ def function_setup( # noqa: PLR0915
|
||||||
removed_async_items = []
|
removed_async_items = []
|
||||||
for index, callback in enumerate(litellm.success_callback): # type: ignore
|
for index, callback in enumerate(litellm.success_callback): # type: ignore
|
||||||
if inspect.iscoroutinefunction(callback):
|
if inspect.iscoroutinefunction(callback):
|
||||||
litellm._async_success_callback.append(callback)
|
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||||
|
callback
|
||||||
|
)
|
||||||
removed_async_items.append(index)
|
removed_async_items.append(index)
|
||||||
elif callback == "dynamodb" or callback == "openmeter":
|
elif callback == "dynamodb" or callback == "openmeter":
|
||||||
# dynamo is an async callback, it's used for the proxy and needs to be async
|
# dynamo is an async callback, it's used for the proxy and needs to be async
|
||||||
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
|
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
|
||||||
litellm._async_success_callback.append(callback)
|
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||||
|
callback
|
||||||
|
)
|
||||||
removed_async_items.append(index)
|
removed_async_items.append(index)
|
||||||
elif (
|
elif (
|
||||||
callback in litellm._known_custom_logger_compatible_callbacks
|
callback in litellm._known_custom_logger_compatible_callbacks
|
||||||
|
@ -509,7 +521,9 @@ def function_setup( # noqa: PLR0915
|
||||||
removed_async_items = []
|
removed_async_items = []
|
||||||
for index, callback in enumerate(litellm.failure_callback): # type: ignore
|
for index, callback in enumerate(litellm.failure_callback): # type: ignore
|
||||||
if inspect.iscoroutinefunction(callback):
|
if inspect.iscoroutinefunction(callback):
|
||||||
litellm._async_failure_callback.append(callback)
|
litellm.logging_callback_manager.add_litellm_async_failure_callback(
|
||||||
|
callback
|
||||||
|
)
|
||||||
removed_async_items.append(index)
|
removed_async_items.append(index)
|
||||||
elif (
|
elif (
|
||||||
callback in litellm._known_custom_logger_compatible_callbacks
|
callback in litellm._known_custom_logger_compatible_callbacks
|
||||||
|
|
|
@ -179,8 +179,9 @@ async def test_async_create_batch(provider):
|
||||||
2. Create Batch Request
|
2. Create Batch Request
|
||||||
3. Retrieve the specific batch
|
3. Retrieve the specific batch
|
||||||
"""
|
"""
|
||||||
|
litellm._turn_on_debug()
|
||||||
print("Testing async create batch")
|
print("Testing async create batch")
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
custom_logger = TestCustomLogger()
|
custom_logger = TestCustomLogger()
|
||||||
litellm.callbacks = [custom_logger, "datadog"]
|
litellm.callbacks = [custom_logger, "datadog"]
|
||||||
|
|
||||||
|
|
113
tests/code_coverage_tests/callback_manager_test.py
Normal file
113
tests/code_coverage_tests/callback_manager_test.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
import ast
|
||||||
|
import os
|
||||||
|
|
||||||
|
ALLOWED_FILES = [
|
||||||
|
# local files
|
||||||
|
"../../litellm/litellm_core_utils/litellm.logging_callback_manager.py",
|
||||||
|
"../../litellm/proxy/common_utils/callback_utils.py",
|
||||||
|
# when running on ci/cd
|
||||||
|
"./litellm/litellm_core_utils/litellm.logging_callback_manager.py",
|
||||||
|
"./litellm/proxy/common_utils/callback_utils.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
warning_msg = "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager"
|
||||||
|
|
||||||
|
|
||||||
|
def check_for_callback_modifications(file_path):
|
||||||
|
"""
|
||||||
|
Checks if any direct modifications to specific litellm callback lists are made in the given file.
|
||||||
|
Also prints the violating line of code.
|
||||||
|
"""
|
||||||
|
print("..checking file=", file_path)
|
||||||
|
if file_path in ALLOWED_FILES:
|
||||||
|
return []
|
||||||
|
|
||||||
|
violations = []
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
try:
|
||||||
|
lines = file.readlines()
|
||||||
|
tree = ast.parse("".join(lines))
|
||||||
|
except SyntaxError:
|
||||||
|
print(f"Warning: Syntax error in file {file_path}")
|
||||||
|
return violations
|
||||||
|
|
||||||
|
protected_lists = [
|
||||||
|
"callbacks",
|
||||||
|
"success_callback",
|
||||||
|
"failure_callback",
|
||||||
|
"_async_success_callback",
|
||||||
|
"_async_failure_callback",
|
||||||
|
]
|
||||||
|
|
||||||
|
forbidden_operations = ["append", "extend", "insert"]
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
# Check for attribute calls like litellm.callbacks.append()
|
||||||
|
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
|
||||||
|
# Get the full attribute chain
|
||||||
|
attr_chain = []
|
||||||
|
current = node.func
|
||||||
|
while isinstance(current, ast.Attribute):
|
||||||
|
attr_chain.append(current.attr)
|
||||||
|
current = current.value
|
||||||
|
if isinstance(current, ast.Name):
|
||||||
|
attr_chain.append(current.id)
|
||||||
|
|
||||||
|
# Reverse to get the chain from root to leaf
|
||||||
|
attr_chain = attr_chain[::-1]
|
||||||
|
|
||||||
|
# Check if the attribute chain starts with 'litellm' and modifies a protected list
|
||||||
|
if (
|
||||||
|
len(attr_chain) >= 3
|
||||||
|
and attr_chain[0] == "litellm"
|
||||||
|
and attr_chain[2] in forbidden_operations
|
||||||
|
):
|
||||||
|
protected_list = attr_chain[1]
|
||||||
|
operation = attr_chain[2]
|
||||||
|
if (
|
||||||
|
protected_list in protected_lists
|
||||||
|
and operation in forbidden_operations
|
||||||
|
):
|
||||||
|
violating_line = lines[node.lineno - 1].strip()
|
||||||
|
violations.append(
|
||||||
|
f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. "
|
||||||
|
f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. "
|
||||||
|
f"Please use LoggingCallbackManager instead. {warning_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return violations
|
||||||
|
|
||||||
|
|
||||||
|
def scan_directory_for_callback_modifications(base_dir):
|
||||||
|
"""
|
||||||
|
Scans all Python files in the directory tree for unauthorized callback list modifications.
|
||||||
|
"""
|
||||||
|
all_violations = []
|
||||||
|
for root, _, files in os.walk(base_dir):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith(".py"):
|
||||||
|
file_path = os.path.join(root, file)
|
||||||
|
violations = check_for_callback_modifications(file_path)
|
||||||
|
all_violations.extend(violations)
|
||||||
|
return all_violations
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_unauthorized_callback_modifications():
|
||||||
|
"""
|
||||||
|
Test to ensure callback lists are not modified directly anywhere in the codebase.
|
||||||
|
"""
|
||||||
|
base_dir = "./litellm" # Adjust this path as needed
|
||||||
|
# base_dir = "../../litellm" # LOCAL TESTING
|
||||||
|
|
||||||
|
violations = scan_directory_for_callback_modifications(base_dir)
|
||||||
|
if violations:
|
||||||
|
print(f"\nFound {len(violations)} callback modification violations:")
|
||||||
|
for violation in violations:
|
||||||
|
print("\n" + violation)
|
||||||
|
raise AssertionError(
|
||||||
|
"Found unauthorized callback modifications. See above for details."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_no_unauthorized_callback_modifications()
|
|
@ -48,6 +48,7 @@ class BaseImageGenTest(ABC):
|
||||||
"""Test basic image generation"""
|
"""Test basic image generation"""
|
||||||
try:
|
try:
|
||||||
custom_logger = TestCustomLogger()
|
custom_logger = TestCustomLogger()
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [custom_logger]
|
litellm.callbacks = [custom_logger]
|
||||||
base_image_generation_call_args = self.get_base_image_generation_call_args()
|
base_image_generation_call_args = self.get_base_image_generation_call_args()
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
179
tests/litellm_utils_tests/test_logging_callback_manager.py
Normal file
179
tests/litellm_utils_tests/test_logging_callback_manager.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
|
||||||
|
from litellm.integrations.langfuse.langfuse_prompt_management import (
|
||||||
|
LangfusePromptManagement,
|
||||||
|
)
|
||||||
|
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||||
|
|
||||||
|
|
||||||
|
# Test fixtures
|
||||||
|
@pytest.fixture
|
||||||
|
def callback_manager():
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
# Reset callbacks before each test
|
||||||
|
manager._reset_all_callbacks()
|
||||||
|
return manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_custom_logger():
|
||||||
|
class TestLogger(CustomLogger):
|
||||||
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return TestLogger()
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases
|
||||||
|
def test_add_string_callback():
|
||||||
|
"""
|
||||||
|
Test adding a string callback to litellm.callbacks - only 1 instance of the string callback should be added
|
||||||
|
"""
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
test_callback = "test_callback"
|
||||||
|
|
||||||
|
# Add string callback
|
||||||
|
manager.add_litellm_callback(test_callback)
|
||||||
|
assert test_callback in litellm.callbacks
|
||||||
|
|
||||||
|
# Test duplicate prevention
|
||||||
|
manager.add_litellm_callback(test_callback)
|
||||||
|
assert litellm.callbacks.count(test_callback) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_langfuse_logger_test():
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
for _ in range(10):
|
||||||
|
langfuse_logger = LangfusePromptManagement()
|
||||||
|
manager.add_litellm_success_callback(langfuse_logger)
|
||||||
|
print("litellm.success_callback: ", litellm.success_callback)
|
||||||
|
assert len(litellm.success_callback) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_multiple_loggers_test():
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
for _ in range(10):
|
||||||
|
langfuse_logger = LangfusePromptManagement()
|
||||||
|
otel_logger = OpenTelemetry()
|
||||||
|
manager.add_litellm_success_callback(langfuse_logger)
|
||||||
|
manager.add_litellm_success_callback(otel_logger)
|
||||||
|
print("litellm.success_callback: ", litellm.success_callback)
|
||||||
|
assert len(litellm.success_callback) == 2
|
||||||
|
|
||||||
|
# Check exactly one instance of each logger type
|
||||||
|
langfuse_count = sum(
|
||||||
|
1
|
||||||
|
for callback in litellm.success_callback
|
||||||
|
if isinstance(callback, LangfusePromptManagement)
|
||||||
|
)
|
||||||
|
otel_count = sum(
|
||||||
|
1
|
||||||
|
for callback in litellm.success_callback
|
||||||
|
if isinstance(callback, OpenTelemetry)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
langfuse_count == 1
|
||||||
|
), "Should have exactly one LangfusePromptManagement instance"
|
||||||
|
assert otel_count == 1, "Should have exactly one OpenTelemetry instance"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_function_callback():
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
|
||||||
|
def test_func(kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add function callback
|
||||||
|
manager.add_litellm_callback(test_func)
|
||||||
|
assert test_func in litellm.callbacks
|
||||||
|
|
||||||
|
# Test duplicate prevention
|
||||||
|
manager.add_litellm_callback(test_func)
|
||||||
|
assert litellm.callbacks.count(test_func) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_custom_logger(mock_custom_logger):
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
|
||||||
|
# Add custom logger
|
||||||
|
manager.add_litellm_callback(mock_custom_logger)
|
||||||
|
assert mock_custom_logger in litellm.callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_multiple_callback_types(mock_custom_logger):
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
|
||||||
|
def test_func(kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
string_callback = "test_callback"
|
||||||
|
|
||||||
|
# Add different types of callbacks
|
||||||
|
manager.add_litellm_callback(string_callback)
|
||||||
|
manager.add_litellm_callback(test_func)
|
||||||
|
manager.add_litellm_callback(mock_custom_logger)
|
||||||
|
|
||||||
|
assert string_callback in litellm.callbacks
|
||||||
|
assert test_func in litellm.callbacks
|
||||||
|
assert mock_custom_logger in litellm.callbacks
|
||||||
|
assert len(litellm.callbacks) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_success_failure_callbacks():
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
|
||||||
|
success_callback = "success_callback"
|
||||||
|
failure_callback = "failure_callback"
|
||||||
|
|
||||||
|
# Add callbacks
|
||||||
|
manager.add_litellm_success_callback(success_callback)
|
||||||
|
manager.add_litellm_failure_callback(failure_callback)
|
||||||
|
|
||||||
|
assert success_callback in litellm.success_callback
|
||||||
|
assert failure_callback in litellm.failure_callback
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_callbacks():
|
||||||
|
manager = LoggingCallbackManager()
|
||||||
|
|
||||||
|
async_success = "async_success"
|
||||||
|
async_failure = "async_failure"
|
||||||
|
|
||||||
|
# Add async callbacks
|
||||||
|
manager.add_litellm_async_success_callback(async_success)
|
||||||
|
manager.add_litellm_async_failure_callback(async_failure)
|
||||||
|
|
||||||
|
assert async_success in litellm._async_success_callback
|
||||||
|
assert async_failure in litellm._async_failure_callback
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_callbacks(callback_manager):
|
||||||
|
# Add various callbacks
|
||||||
|
callback_manager.add_litellm_callback("test")
|
||||||
|
callback_manager.add_litellm_success_callback("success")
|
||||||
|
callback_manager.add_litellm_failure_callback("failure")
|
||||||
|
callback_manager.add_litellm_async_success_callback("async_success")
|
||||||
|
callback_manager.add_litellm_async_failure_callback("async_failure")
|
||||||
|
|
||||||
|
# Reset all callbacks
|
||||||
|
callback_manager._reset_all_callbacks()
|
||||||
|
|
||||||
|
# Verify all callback lists are empty
|
||||||
|
assert len(litellm.callbacks) == 0
|
||||||
|
assert len(litellm.success_callback) == 0
|
||||||
|
assert len(litellm.failure_callback) == 0
|
||||||
|
assert len(litellm._async_success_callback) == 0
|
||||||
|
assert len(litellm._async_failure_callback) == 0
|
|
@ -846,6 +846,7 @@ async def test_async_embedding_openai():
|
||||||
assert len(customHandler_success.errors) == 0
|
assert len(customHandler_success.errors) == 0
|
||||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
# test failure callback
|
# test failure callback
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
try:
|
try:
|
||||||
response = await litellm.aembedding(
|
response = await litellm.aembedding(
|
||||||
|
@ -882,6 +883,7 @@ def test_amazing_sync_embedding():
|
||||||
assert len(customHandler_success.errors) == 0
|
assert len(customHandler_success.errors) == 0
|
||||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
# test failure callback
|
# test failure callback
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
try:
|
try:
|
||||||
response = litellm.embedding(
|
response = litellm.embedding(
|
||||||
|
@ -916,6 +918,7 @@ async def test_async_embedding_azure():
|
||||||
assert len(customHandler_success.errors) == 0
|
assert len(customHandler_success.errors) == 0
|
||||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
# test failure callback
|
# test failure callback
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
try:
|
try:
|
||||||
response = await litellm.aembedding(
|
response = await litellm.aembedding(
|
||||||
|
@ -956,6 +959,7 @@ async def test_async_embedding_bedrock():
|
||||||
assert len(customHandler_success.errors) == 0
|
assert len(customHandler_success.errors) == 0
|
||||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
# test failure callback
|
# test failure callback
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
try:
|
try:
|
||||||
response = await litellm.aembedding(
|
response = await litellm.aembedding(
|
||||||
|
@ -1123,6 +1127,7 @@ def test_image_generation_openai():
|
||||||
assert len(customHandler_success.errors) == 0
|
assert len(customHandler_success.errors) == 0
|
||||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
# test failure callback
|
# test failure callback
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
try:
|
try:
|
||||||
response = litellm.image_generation(
|
response = litellm.image_generation(
|
||||||
|
|
|
@ -415,6 +415,8 @@ async def test_async_chat_azure():
|
||||||
len(customHandler_completion_azure_router.states) == 3
|
len(customHandler_completion_azure_router.states) == 3
|
||||||
) # pre, post, success
|
) # pre, post, success
|
||||||
# streaming
|
# streaming
|
||||||
|
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_streaming_azure_router]
|
litellm.callbacks = [customHandler_streaming_azure_router]
|
||||||
router2 = Router(model_list=model_list, num_retries=0) # type: ignore
|
router2 = Router(model_list=model_list, num_retries=0) # type: ignore
|
||||||
response = await router2.acompletion(
|
response = await router2.acompletion(
|
||||||
|
@ -445,6 +447,8 @@ async def test_async_chat_azure():
|
||||||
"rpm": 1800,
|
"rpm": 1800,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
|
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
|
||||||
try:
|
try:
|
||||||
|
@ -507,6 +511,7 @@ async def test_async_embedding_azure():
|
||||||
"rpm": 1800,
|
"rpm": 1800,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
|
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -261,6 +261,7 @@ def test_azure_completion_stream():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_custom_handler_completion():
|
async def test_async_custom_handler_completion():
|
||||||
try:
|
try:
|
||||||
|
litellm._turn_on_debug
|
||||||
customHandler_success = MyCustomHandler()
|
customHandler_success = MyCustomHandler()
|
||||||
customHandler_failure = MyCustomHandler()
|
customHandler_failure = MyCustomHandler()
|
||||||
# success
|
# success
|
||||||
|
@ -284,6 +285,7 @@ async def test_async_custom_handler_completion():
|
||||||
== "gpt-3.5-turbo"
|
== "gpt-3.5-turbo"
|
||||||
)
|
)
|
||||||
# failure
|
# failure
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
litellm.callbacks = [customHandler_failure]
|
litellm.callbacks = [customHandler_failure]
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
|
|
@ -13,6 +13,7 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../")
|
0, os.path.abspath("../")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import litellm
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from litellm import utils, Router
|
from litellm import utils, Router
|
||||||
|
|
||||||
|
@ -124,6 +125,7 @@ def test_rate_limit(
|
||||||
ExpectNoException: Signfies that no other error has happened. A NOP
|
ExpectNoException: Signfies that no other error has happened. A NOP
|
||||||
"""
|
"""
|
||||||
# Can send more messages then we're going to; so don't expect a rate limit error
|
# Can send more messages then we're going to; so don't expect a rate limit error
|
||||||
|
litellm.logging_callback_manager._reset_all_callbacks()
|
||||||
args = locals()
|
args = locals()
|
||||||
print(f"args: {args}")
|
print(f"args: {args}")
|
||||||
expected_exception = (
|
expected_exception = (
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue