mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
(Refactor / QA) - Use LoggingCallbackManager
to append callbacks and ensure no duplicate callbacks are added (#8112)
* LoggingCallbackManager * add logging_callback_manager * use logging_callback_manager * add add_litellm_failure_callback * use add_litellm_callback * use add_litellm_async_success_callback * add_litellm_async_failure_callback * linting fix * fix logging callback manager * test_duplicate_multiple_loggers_test * use _reset_all_callbacks * fix testing with dup callbacks * test_basic_image_generation * reset callbacks for tests * fix check for _add_custom_logger_to_list * fix test_amazing_sync_embedding * fix _get_custom_logger_key * fix batches testing * fix _reset_all_callbacks * fix _check_callback_list_size * add callback_manager_test * fix test gemini-2.0-flash-thinking-exp-01-21
This commit is contained in:
parent
3eac1634fa
commit
8a235e7d38
19 changed files with 607 additions and 59 deletions
|
@ -994,6 +994,7 @@ jobs:
|
|||
- run: ruff check ./litellm
|
||||
# - run: python ./tests/documentation_tests/test_general_setting_keys.py
|
||||
- run: python ./tests/code_coverage_tests/router_code_coverage.py
|
||||
- run: python ./tests/code_coverage_tests/callback_manager_test.py
|
||||
- run: python ./tests/code_coverage_tests/recursive_detector.py
|
||||
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
|
||||
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
|
||||
|
|
|
@ -38,6 +38,7 @@ from litellm.proxy._types import (
|
|||
)
|
||||
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
|
||||
import httpx
|
||||
import dotenv
|
||||
from enum import Enum
|
||||
|
@ -50,6 +51,7 @@ if set_verbose == True:
|
|||
_turn_on_debug()
|
||||
###############################################
|
||||
### Callbacks /Logging / Success / Failure Handlers #####
|
||||
logging_callback_manager = LoggingCallbackManager()
|
||||
input_callback: List[Union[str, Callable, CustomLogger]] = []
|
||||
success_callback: List[Union[str, Callable, CustomLogger]] = []
|
||||
failure_callback: List[Union[str, Callable, CustomLogger]] = []
|
||||
|
|
|
@ -207,9 +207,9 @@ class Cache:
|
|||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.success_callback.append("cache")
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append("cache")
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||
self.type = type
|
||||
self.namespace = namespace
|
||||
|
@ -774,9 +774,9 @@ def enable_cache(
|
|||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.success_callback.append("cache")
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append("cache")
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
|
||||
if litellm.cache is None:
|
||||
litellm.cache = Cache(
|
||||
|
|
207
litellm/litellm_core_utils/logging_callback_manager.py
Normal file
207
litellm/litellm_core_utils/logging_callback_manager.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
from typing import Callable, List, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class LoggingCallbackManager:
|
||||
"""
|
||||
A centralized class that allows easy add / remove callbacks for litellm.
|
||||
|
||||
Goals of this class:
|
||||
- Prevent adding duplicate callbacks / success_callback / failure_callback
|
||||
- Keep a reasonable MAX_CALLBACKS limit (this ensures callbacks don't exponentially grow and consume CPU Resources)
|
||||
"""
|
||||
|
||||
# healthy maximum number of callbacks - unlikely someone needs more than 20
|
||||
MAX_CALLBACKS = 30
|
||||
|
||||
def add_litellm_input_callback(self, callback: Union[CustomLogger, str]):
|
||||
"""
|
||||
Add a input callback to litellm.input_callback
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm.input_callback
|
||||
)
|
||||
|
||||
def add_litellm_service_callback(
|
||||
self, callback: Union[CustomLogger, str, Callable]
|
||||
):
|
||||
"""
|
||||
Add a service callback to litellm.service_callback
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm.service_callback
|
||||
)
|
||||
|
||||
def add_litellm_callback(self, callback: Union[CustomLogger, str, Callable]):
|
||||
"""
|
||||
Add a callback to litellm.callbacks
|
||||
|
||||
Ensures no duplicates are added.
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm.callbacks # type: ignore
|
||||
)
|
||||
|
||||
def add_litellm_success_callback(
|
||||
self, callback: Union[CustomLogger, str, Callable]
|
||||
):
|
||||
"""
|
||||
Add a success callback to `litellm.success_callback`
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm.success_callback
|
||||
)
|
||||
|
||||
def add_litellm_failure_callback(
|
||||
self, callback: Union[CustomLogger, str, Callable]
|
||||
):
|
||||
"""
|
||||
Add a failure callback to `litellm.failure_callback`
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm.failure_callback
|
||||
)
|
||||
|
||||
def add_litellm_async_success_callback(
|
||||
self, callback: Union[CustomLogger, Callable, str]
|
||||
):
|
||||
"""
|
||||
Add a success callback to litellm._async_success_callback
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm._async_success_callback
|
||||
)
|
||||
|
||||
def add_litellm_async_failure_callback(
|
||||
self, callback: Union[CustomLogger, Callable, str]
|
||||
):
|
||||
"""
|
||||
Add a failure callback to litellm._async_failure_callback
|
||||
"""
|
||||
self._safe_add_callback_to_list(
|
||||
callback=callback, parent_list=litellm._async_failure_callback
|
||||
)
|
||||
|
||||
def _add_string_callback_to_list(
|
||||
self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||
):
|
||||
"""
|
||||
Add a string callback to a list, if the callback is already in the list, do not add it again.
|
||||
"""
|
||||
if callback not in parent_list:
|
||||
parent_list.append(callback)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Callback {callback} already exists in {parent_list}, not adding again.."
|
||||
)
|
||||
|
||||
def _check_callback_list_size(
|
||||
self, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if adding another callback would exceed MAX_CALLBACKS
|
||||
Returns True if safe to add, False if would exceed limit
|
||||
"""
|
||||
if len(parent_list) >= self.MAX_CALLBACKS:
|
||||
verbose_logger.warning(
|
||||
f"Cannot add callback - would exceed MAX_CALLBACKS limit of {self.MAX_CALLBACKS}. Current callbacks: {len(parent_list)}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _safe_add_callback_to_list(
|
||||
self,
|
||||
callback: Union[CustomLogger, Callable, str],
|
||||
parent_list: List[Union[CustomLogger, Callable, str]],
|
||||
):
|
||||
"""
|
||||
Safe add a callback to a list, if the callback is already in the list, do not add it again.
|
||||
|
||||
Ensures no duplicates are added for `str`, `Callable`, and `CustomLogger` callbacks.
|
||||
"""
|
||||
# Check max callbacks limit first
|
||||
if not self._check_callback_list_size(parent_list):
|
||||
return
|
||||
|
||||
if isinstance(callback, str):
|
||||
self._add_string_callback_to_list(
|
||||
callback=callback, parent_list=parent_list
|
||||
)
|
||||
elif isinstance(callback, CustomLogger):
|
||||
self._add_custom_logger_to_list(
|
||||
custom_logger=callback,
|
||||
parent_list=parent_list,
|
||||
)
|
||||
elif callable(callback):
|
||||
self._add_callback_function_to_list(
|
||||
callback=callback, parent_list=parent_list
|
||||
)
|
||||
|
||||
def _add_callback_function_to_list(
|
||||
self, callback: Callable, parent_list: List[Union[CustomLogger, Callable, str]]
|
||||
):
|
||||
"""
|
||||
Add a callback function to a list, if the callback is already in the list, do not add it again.
|
||||
"""
|
||||
# Check if the function already exists in the list by comparing function objects
|
||||
if callback not in parent_list:
|
||||
parent_list.append(callback)
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
f"Callback function {callback.__name__} already exists in {parent_list}, not adding again.."
|
||||
)
|
||||
|
||||
def _add_custom_logger_to_list(
|
||||
self,
|
||||
custom_logger: CustomLogger,
|
||||
parent_list: List[Union[CustomLogger, Callable, str]],
|
||||
):
|
||||
"""
|
||||
Add a custom logger to a list, if another instance of the same custom logger exists in the list, do not add it again.
|
||||
"""
|
||||
# Check if an instance of the same class already exists in the list
|
||||
custom_logger_key = self._get_custom_logger_key(custom_logger)
|
||||
custom_logger_type_name = type(custom_logger).__name__
|
||||
for existing_logger in parent_list:
|
||||
if (
|
||||
isinstance(existing_logger, CustomLogger)
|
||||
and self._get_custom_logger_key(existing_logger) == custom_logger_key
|
||||
):
|
||||
verbose_logger.debug(
|
||||
f"Custom logger of type {custom_logger_type_name}, key: {custom_logger_key} already exists in {parent_list}, not adding again.."
|
||||
)
|
||||
return
|
||||
parent_list.append(custom_logger)
|
||||
|
||||
def _get_custom_logger_key(self, custom_logger: CustomLogger):
|
||||
"""
|
||||
Get a unique key for a custom logger that considers only fundamental instance variables
|
||||
|
||||
Returns:
|
||||
str: A unique key combining the class name and fundamental instance variables (str, bool, int)
|
||||
"""
|
||||
key_parts = [type(custom_logger).__name__]
|
||||
|
||||
# Add only fundamental type instance variables to the key
|
||||
for attr_name, attr_value in vars(custom_logger).items():
|
||||
if not attr_name.startswith("_"): # Skip private attributes
|
||||
if isinstance(attr_value, (str, bool, int)):
|
||||
key_parts.append(f"{attr_name}={attr_value}")
|
||||
|
||||
return "-".join(key_parts)
|
||||
|
||||
def _reset_all_callbacks(self):
|
||||
"""
|
||||
Reset all callbacks to an empty list
|
||||
|
||||
Note: this is an internal function and should be used sparingly.
|
||||
"""
|
||||
litellm.input_callback = []
|
||||
litellm.success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm._async_failure_callback = []
|
||||
litellm.callbacks = []
|
|
@ -13,7 +13,7 @@ def initialize_aporia(litellm_params, guardrail):
|
|||
event_hook=litellm_params["mode"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_aporia_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_aporia_callback)
|
||||
|
||||
|
||||
def initialize_bedrock(litellm_params, guardrail):
|
||||
|
@ -28,7 +28,7 @@ def initialize_bedrock(litellm_params, guardrail):
|
|||
guardrailVersion=litellm_params["guardrailVersion"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_bedrock_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_bedrock_callback)
|
||||
|
||||
|
||||
def initialize_lakera(litellm_params, guardrail):
|
||||
|
@ -42,7 +42,7 @@ def initialize_lakera(litellm_params, guardrail):
|
|||
category_thresholds=litellm_params.get("category_thresholds"),
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_lakera_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_lakera_callback)
|
||||
|
||||
|
||||
def initialize_aim(litellm_params, guardrail):
|
||||
|
@ -55,7 +55,7 @@ def initialize_aim(litellm_params, guardrail):
|
|||
event_hook=litellm_params["mode"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_aim_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_aim_callback)
|
||||
|
||||
|
||||
def initialize_presidio(litellm_params, guardrail):
|
||||
|
@ -71,7 +71,7 @@ def initialize_presidio(litellm_params, guardrail):
|
|||
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_presidio_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_presidio_callback)
|
||||
|
||||
if litellm_params["output_parse_pii"]:
|
||||
_success_callback = _OPTIONAL_PresidioPIIMasking(
|
||||
|
@ -81,7 +81,7 @@ def initialize_presidio(litellm_params, guardrail):
|
|||
presidio_ad_hoc_recognizers=litellm_params["presidio_ad_hoc_recognizers"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_success_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_success_callback)
|
||||
|
||||
|
||||
def initialize_hide_secrets(litellm_params, guardrail):
|
||||
|
@ -93,7 +93,7 @@ def initialize_hide_secrets(litellm_params, guardrail):
|
|||
guardrail_name=guardrail["guardrail_name"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_secret_detection_object)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_secret_detection_object)
|
||||
|
||||
|
||||
def initialize_guardrails_ai(litellm_params, guardrail):
|
||||
|
@ -111,4 +111,4 @@ def initialize_guardrails_ai(litellm_params, guardrail):
|
|||
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_guardrails_ai_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_guardrails_ai_callback)
|
||||
|
|
|
@ -157,7 +157,7 @@ def init_guardrails_v2(
|
|||
event_hook=litellm_params["mode"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.callbacks.append(_guardrail_callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
||||
|
||||
|
|
|
@ -736,7 +736,7 @@ user_api_key_cache = DualCache(
|
|||
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
|
||||
dual_cache=user_api_key_cache
|
||||
)
|
||||
litellm.callbacks.append(model_max_budget_limiter)
|
||||
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
|
||||
redis_usage_cache: Optional[RedisCache] = (
|
||||
None # redis cache used for tracking spend, tpm/rpm limits
|
||||
)
|
||||
|
@ -934,7 +934,7 @@ def cost_tracking():
|
|||
if isinstance(litellm._async_success_callback, list):
|
||||
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
||||
if (_PROXY_track_cost_callback) not in litellm._async_success_callback: # type: ignore
|
||||
litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(_PROXY_track_cost_callback) # type: ignore
|
||||
|
||||
|
||||
def error_tracking():
|
||||
|
@ -943,7 +943,7 @@ def error_tracking():
|
|||
if isinstance(litellm.failure_callback, list):
|
||||
verbose_proxy_logger.debug("setting litellm failure callback to track cost")
|
||||
if (_PROXY_failure_handler) not in litellm.failure_callback: # type: ignore
|
||||
litellm.failure_callback.append(_PROXY_failure_handler) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(_PROXY_failure_handler) # type: ignore
|
||||
|
||||
|
||||
def _set_spend_logs_payload(
|
||||
|
@ -1890,12 +1890,14 @@ class ProxyConfig:
|
|||
for callback in value:
|
||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||
if "." in callback:
|
||||
litellm.success_callback.append(
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
get_instance_fn(value=callback)
|
||||
)
|
||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||
else:
|
||||
litellm.success_callback.append(callback)
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
callback
|
||||
)
|
||||
if "prometheus" in callback:
|
||||
if not premium_user:
|
||||
raise Exception(
|
||||
|
@ -1919,12 +1921,14 @@ class ProxyConfig:
|
|||
for callback in value:
|
||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||
if "." in callback:
|
||||
litellm.failure_callback.append(
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||
get_instance_fn(value=callback)
|
||||
)
|
||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||
else:
|
||||
litellm.failure_callback.append(callback)
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||
callback
|
||||
)
|
||||
print( # noqa
|
||||
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
|
||||
) # noqa
|
||||
|
@ -2215,7 +2219,7 @@ class ProxyConfig:
|
|||
},
|
||||
)
|
||||
if _logger is not None:
|
||||
litellm.callbacks.append(_logger)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_logger)
|
||||
pass
|
||||
|
||||
def initialize_secret_manager(self, key_management_system: Optional[str]):
|
||||
|
@ -2497,7 +2501,9 @@ class ProxyConfig:
|
|||
success_callback, "success"
|
||||
)
|
||||
elif success_callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(success_callback)
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
success_callback
|
||||
)
|
||||
|
||||
# Add failure callbacks from DB to litellm
|
||||
if failure_callbacks is not None and isinstance(failure_callbacks, list):
|
||||
|
@ -2510,7 +2516,9 @@ class ProxyConfig:
|
|||
failure_callback, "failure"
|
||||
)
|
||||
elif failure_callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(failure_callback)
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||
failure_callback
|
||||
)
|
||||
|
||||
def _add_environment_variables_from_db_config(self, config_data: dict) -> None:
|
||||
"""
|
||||
|
|
|
@ -323,8 +323,8 @@ class ProxyLogging:
|
|||
# NOTE: ENSURE we only add callbacks when alerting is on
|
||||
# We should NOT add callbacks when alerting is off
|
||||
if "daily_reports" in self.alert_types:
|
||||
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
|
||||
litellm.success_callback.append(
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
self.slack_alerting_instance.response_taking_too_long_callback
|
||||
)
|
||||
|
||||
|
@ -332,10 +332,10 @@ class ProxyLogging:
|
|||
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
|
||||
|
||||
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
||||
litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
|
||||
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
|
||||
litellm.callbacks.append(self.cache_control_check) # type: ignore
|
||||
litellm.callbacks.append(self.service_logging_obj) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
|
||||
|
@ -348,13 +348,13 @@ class ProxyLogging:
|
|||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback) # type: ignore
|
||||
if callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
|
||||
if callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
|
||||
if callback not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
|
||||
if callback not in litellm.service_callback:
|
||||
litellm.service_callback.append(callback) # type: ignore
|
||||
|
||||
|
|
|
@ -483,15 +483,21 @@ class Router:
|
|||
self.access_groups = None
|
||||
## USAGE TRACKING ##
|
||||
if isinstance(litellm._async_success_callback, list):
|
||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||
self.deployment_callback_on_success
|
||||
)
|
||||
else:
|
||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||
self.deployment_callback_on_success
|
||||
)
|
||||
if isinstance(litellm.success_callback, list):
|
||||
litellm.success_callback.append(self.sync_deployment_callback_on_success)
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
self.sync_deployment_callback_on_success
|
||||
)
|
||||
else:
|
||||
litellm.success_callback = [self.sync_deployment_callback_on_success]
|
||||
if isinstance(litellm._async_failure_callback, list):
|
||||
litellm._async_failure_callback.append(
|
||||
litellm.logging_callback_manager.add_litellm_async_failure_callback(
|
||||
self.async_deployment_callback_on_failure
|
||||
)
|
||||
else:
|
||||
|
@ -500,7 +506,9 @@ class Router:
|
|||
]
|
||||
## COOLDOWNS ##
|
||||
if isinstance(litellm.failure_callback, list):
|
||||
litellm.failure_callback.append(self.deployment_callback_on_failure)
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||
self.deployment_callback_on_failure
|
||||
)
|
||||
else:
|
||||
litellm.failure_callback = [self.deployment_callback_on_failure]
|
||||
verbose_router_logger.debug(
|
||||
|
@ -606,7 +614,7 @@ class Router:
|
|||
model_list=self.model_list,
|
||||
)
|
||||
if _callback is not None:
|
||||
litellm.callbacks.append(_callback)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_callback)
|
||||
|
||||
def routing_strategy_init(
|
||||
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
|
||||
|
@ -625,7 +633,7 @@ class Router:
|
|||
else:
|
||||
litellm.input_callback = [self.leastbusy_logger] # type: ignore
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.leastbusy_logger) # type: ignore
|
||||
elif (
|
||||
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
|
||||
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
|
||||
|
@ -636,7 +644,7 @@ class Router:
|
|||
routing_args=routing_strategy_args,
|
||||
)
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger) # type: ignore
|
||||
elif (
|
||||
routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
|
||||
or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
|
||||
|
@ -647,7 +655,7 @@ class Router:
|
|||
routing_args=routing_strategy_args,
|
||||
)
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.lowesttpm_logger_v2) # type: ignore
|
||||
elif (
|
||||
routing_strategy == RoutingStrategy.LATENCY_BASED.value
|
||||
or routing_strategy == RoutingStrategy.LATENCY_BASED
|
||||
|
@ -658,7 +666,7 @@ class Router:
|
|||
routing_args=routing_strategy_args,
|
||||
)
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.lowestlatency_logger) # type: ignore
|
||||
elif (
|
||||
routing_strategy == RoutingStrategy.COST_BASED.value
|
||||
or routing_strategy == RoutingStrategy.COST_BASED
|
||||
|
@ -669,7 +677,7 @@ class Router:
|
|||
routing_args={},
|
||||
)
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self.lowestcost_logger) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self.lowestcost_logger) # type: ignore
|
||||
else:
|
||||
pass
|
||||
|
||||
|
@ -5835,8 +5843,8 @@ class Router:
|
|||
|
||||
self.slack_alerting_logger = _slack_alerting_logger
|
||||
|
||||
litellm.callbacks.append(_slack_alerting_logger) # type: ignore
|
||||
litellm.success_callback.append(
|
||||
litellm.logging_callback_manager.add_litellm_callback(_slack_alerting_logger) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
_slack_alerting_logger.response_taking_too_long_callback
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
|
|
|
@ -64,7 +64,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
|
||||
# Add self to litellm callbacks if it's a list
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.callbacks.append(self) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore
|
||||
|
||||
async def async_filter_deployments(
|
||||
self,
|
||||
|
|
|
@ -352,8 +352,12 @@ def _add_custom_logger_callback_to_specific_event(
|
|||
and _custom_logger_class_exists_in_success_callbacks(callback_class)
|
||||
is False
|
||||
):
|
||||
litellm.success_callback.append(callback_class)
|
||||
litellm._async_success_callback.append(callback_class)
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(
|
||||
callback_class
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||
callback_class
|
||||
)
|
||||
if callback in litellm.success_callback:
|
||||
litellm.success_callback.remove(
|
||||
callback
|
||||
|
@ -367,8 +371,12 @@ def _add_custom_logger_callback_to_specific_event(
|
|||
and _custom_logger_class_exists_in_failure_callbacks(callback_class)
|
||||
is False
|
||||
):
|
||||
litellm.failure_callback.append(callback_class)
|
||||
litellm._async_failure_callback.append(callback_class)
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(
|
||||
callback_class
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_async_failure_callback(
|
||||
callback_class
|
||||
)
|
||||
if callback in litellm.failure_callback:
|
||||
litellm.failure_callback.remove(
|
||||
callback
|
||||
|
@ -447,13 +455,13 @@ def function_setup( # noqa: PLR0915
|
|||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback) # type: ignore
|
||||
if callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
|
||||
if callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
|
||||
if callback not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback) # type: ignore
|
||||
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
|
||||
print_verbose(
|
||||
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
|
||||
)
|
||||
|
@ -488,12 +496,16 @@ def function_setup( # noqa: PLR0915
|
|||
removed_async_items = []
|
||||
for index, callback in enumerate(litellm.success_callback): # type: ignore
|
||||
if inspect.iscoroutinefunction(callback):
|
||||
litellm._async_success_callback.append(callback)
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||
callback
|
||||
)
|
||||
removed_async_items.append(index)
|
||||
elif callback == "dynamodb" or callback == "openmeter":
|
||||
# dynamo is an async callback, it's used for the proxy and needs to be async
|
||||
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
|
||||
litellm._async_success_callback.append(callback)
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback(
|
||||
callback
|
||||
)
|
||||
removed_async_items.append(index)
|
||||
elif (
|
||||
callback in litellm._known_custom_logger_compatible_callbacks
|
||||
|
@ -509,7 +521,9 @@ def function_setup( # noqa: PLR0915
|
|||
removed_async_items = []
|
||||
for index, callback in enumerate(litellm.failure_callback): # type: ignore
|
||||
if inspect.iscoroutinefunction(callback):
|
||||
litellm._async_failure_callback.append(callback)
|
||||
litellm.logging_callback_manager.add_litellm_async_failure_callback(
|
||||
callback
|
||||
)
|
||||
removed_async_items.append(index)
|
||||
elif (
|
||||
callback in litellm._known_custom_logger_compatible_callbacks
|
||||
|
|
|
@ -179,8 +179,9 @@ async def test_async_create_batch(provider):
|
|||
2. Create Batch Request
|
||||
3. Retrieve the specific batch
|
||||
"""
|
||||
litellm._turn_on_debug()
|
||||
print("Testing async create batch")
|
||||
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger, "datadog"]
|
||||
|
||||
|
|
113
tests/code_coverage_tests/callback_manager_test.py
Normal file
113
tests/code_coverage_tests/callback_manager_test.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
import ast
|
||||
import os
|
||||
|
||||
ALLOWED_FILES = [
|
||||
# local files
|
||||
"../../litellm/litellm_core_utils/litellm.logging_callback_manager.py",
|
||||
"../../litellm/proxy/common_utils/callback_utils.py",
|
||||
# when running on ci/cd
|
||||
"./litellm/litellm_core_utils/litellm.logging_callback_manager.py",
|
||||
"./litellm/proxy/common_utils/callback_utils.py",
|
||||
]
|
||||
|
||||
warning_msg = "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager"
|
||||
|
||||
|
||||
def check_for_callback_modifications(file_path):
|
||||
"""
|
||||
Checks if any direct modifications to specific litellm callback lists are made in the given file.
|
||||
Also prints the violating line of code.
|
||||
"""
|
||||
print("..checking file=", file_path)
|
||||
if file_path in ALLOWED_FILES:
|
||||
return []
|
||||
|
||||
violations = []
|
||||
with open(file_path, "r") as file:
|
||||
try:
|
||||
lines = file.readlines()
|
||||
tree = ast.parse("".join(lines))
|
||||
except SyntaxError:
|
||||
print(f"Warning: Syntax error in file {file_path}")
|
||||
return violations
|
||||
|
||||
protected_lists = [
|
||||
"callbacks",
|
||||
"success_callback",
|
||||
"failure_callback",
|
||||
"_async_success_callback",
|
||||
"_async_failure_callback",
|
||||
]
|
||||
|
||||
forbidden_operations = ["append", "extend", "insert"]
|
||||
|
||||
for node in ast.walk(tree):
|
||||
# Check for attribute calls like litellm.callbacks.append()
|
||||
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
|
||||
# Get the full attribute chain
|
||||
attr_chain = []
|
||||
current = node.func
|
||||
while isinstance(current, ast.Attribute):
|
||||
attr_chain.append(current.attr)
|
||||
current = current.value
|
||||
if isinstance(current, ast.Name):
|
||||
attr_chain.append(current.id)
|
||||
|
||||
# Reverse to get the chain from root to leaf
|
||||
attr_chain = attr_chain[::-1]
|
||||
|
||||
# Check if the attribute chain starts with 'litellm' and modifies a protected list
|
||||
if (
|
||||
len(attr_chain) >= 3
|
||||
and attr_chain[0] == "litellm"
|
||||
and attr_chain[2] in forbidden_operations
|
||||
):
|
||||
protected_list = attr_chain[1]
|
||||
operation = attr_chain[2]
|
||||
if (
|
||||
protected_list in protected_lists
|
||||
and operation in forbidden_operations
|
||||
):
|
||||
violating_line = lines[node.lineno - 1].strip()
|
||||
violations.append(
|
||||
f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. "
|
||||
f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. "
|
||||
f"Please use LoggingCallbackManager instead. {warning_msg}"
|
||||
)
|
||||
|
||||
return violations
|
||||
|
||||
|
||||
def scan_directory_for_callback_modifications(base_dir):
|
||||
"""
|
||||
Scans all Python files in the directory tree for unauthorized callback list modifications.
|
||||
"""
|
||||
all_violations = []
|
||||
for root, _, files in os.walk(base_dir):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
violations = check_for_callback_modifications(file_path)
|
||||
all_violations.extend(violations)
|
||||
return all_violations
|
||||
|
||||
|
||||
def test_no_unauthorized_callback_modifications():
|
||||
"""
|
||||
Test to ensure callback lists are not modified directly anywhere in the codebase.
|
||||
"""
|
||||
base_dir = "./litellm" # Adjust this path as needed
|
||||
# base_dir = "../../litellm" # LOCAL TESTING
|
||||
|
||||
violations = scan_directory_for_callback_modifications(base_dir)
|
||||
if violations:
|
||||
print(f"\nFound {len(violations)} callback modification violations:")
|
||||
for violation in violations:
|
||||
print("\n" + violation)
|
||||
raise AssertionError(
|
||||
"Found unauthorized callback modifications. See above for details."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_no_unauthorized_callback_modifications()
|
|
@ -48,6 +48,7 @@ class BaseImageGenTest(ABC):
|
|||
"""Test basic image generation"""
|
||||
try:
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [custom_logger]
|
||||
base_image_generation_call_args = self.get_base_image_generation_call_args()
|
||||
litellm.set_verbose = True
|
||||
|
|
179
tests/litellm_utils_tests/test_logging_callback_manager.py
Normal file
179
tests/litellm_utils_tests/test_logging_callback_manager.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
|
||||
from litellm.integrations.langfuse.langfuse_prompt_management import (
|
||||
LangfusePromptManagement,
|
||||
)
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
|
||||
def callback_manager():
|
||||
manager = LoggingCallbackManager()
|
||||
# Reset callbacks before each test
|
||||
manager._reset_all_callbacks()
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_custom_logger():
|
||||
class TestLogger(CustomLogger):
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
return TestLogger()
|
||||
|
||||
|
||||
# Test cases
|
||||
def test_add_string_callback():
|
||||
"""
|
||||
Test adding a string callback to litellm.callbacks - only 1 instance of the string callback should be added
|
||||
"""
|
||||
manager = LoggingCallbackManager()
|
||||
test_callback = "test_callback"
|
||||
|
||||
# Add string callback
|
||||
manager.add_litellm_callback(test_callback)
|
||||
assert test_callback in litellm.callbacks
|
||||
|
||||
# Test duplicate prevention
|
||||
manager.add_litellm_callback(test_callback)
|
||||
assert litellm.callbacks.count(test_callback) == 1
|
||||
|
||||
|
||||
def test_duplicate_langfuse_logger_test():
|
||||
manager = LoggingCallbackManager()
|
||||
for _ in range(10):
|
||||
langfuse_logger = LangfusePromptManagement()
|
||||
manager.add_litellm_success_callback(langfuse_logger)
|
||||
print("litellm.success_callback: ", litellm.success_callback)
|
||||
assert len(litellm.success_callback) == 1
|
||||
|
||||
|
||||
def test_duplicate_multiple_loggers_test():
|
||||
manager = LoggingCallbackManager()
|
||||
for _ in range(10):
|
||||
langfuse_logger = LangfusePromptManagement()
|
||||
otel_logger = OpenTelemetry()
|
||||
manager.add_litellm_success_callback(langfuse_logger)
|
||||
manager.add_litellm_success_callback(otel_logger)
|
||||
print("litellm.success_callback: ", litellm.success_callback)
|
||||
assert len(litellm.success_callback) == 2
|
||||
|
||||
# Check exactly one instance of each logger type
|
||||
langfuse_count = sum(
|
||||
1
|
||||
for callback in litellm.success_callback
|
||||
if isinstance(callback, LangfusePromptManagement)
|
||||
)
|
||||
otel_count = sum(
|
||||
1
|
||||
for callback in litellm.success_callback
|
||||
if isinstance(callback, OpenTelemetry)
|
||||
)
|
||||
|
||||
assert (
|
||||
langfuse_count == 1
|
||||
), "Should have exactly one LangfusePromptManagement instance"
|
||||
assert otel_count == 1, "Should have exactly one OpenTelemetry instance"
|
||||
|
||||
|
||||
def test_add_function_callback():
|
||||
manager = LoggingCallbackManager()
|
||||
|
||||
def test_func(kwargs):
|
||||
pass
|
||||
|
||||
# Add function callback
|
||||
manager.add_litellm_callback(test_func)
|
||||
assert test_func in litellm.callbacks
|
||||
|
||||
# Test duplicate prevention
|
||||
manager.add_litellm_callback(test_func)
|
||||
assert litellm.callbacks.count(test_func) == 1
|
||||
|
||||
|
||||
def test_add_custom_logger(mock_custom_logger):
|
||||
manager = LoggingCallbackManager()
|
||||
|
||||
# Add custom logger
|
||||
manager.add_litellm_callback(mock_custom_logger)
|
||||
assert mock_custom_logger in litellm.callbacks
|
||||
|
||||
|
||||
def test_add_multiple_callback_types(mock_custom_logger):
|
||||
manager = LoggingCallbackManager()
|
||||
|
||||
def test_func(kwargs):
|
||||
pass
|
||||
|
||||
string_callback = "test_callback"
|
||||
|
||||
# Add different types of callbacks
|
||||
manager.add_litellm_callback(string_callback)
|
||||
manager.add_litellm_callback(test_func)
|
||||
manager.add_litellm_callback(mock_custom_logger)
|
||||
|
||||
assert string_callback in litellm.callbacks
|
||||
assert test_func in litellm.callbacks
|
||||
assert mock_custom_logger in litellm.callbacks
|
||||
assert len(litellm.callbacks) == 3
|
||||
|
||||
|
||||
def test_success_failure_callbacks():
|
||||
manager = LoggingCallbackManager()
|
||||
|
||||
success_callback = "success_callback"
|
||||
failure_callback = "failure_callback"
|
||||
|
||||
# Add callbacks
|
||||
manager.add_litellm_success_callback(success_callback)
|
||||
manager.add_litellm_failure_callback(failure_callback)
|
||||
|
||||
assert success_callback in litellm.success_callback
|
||||
assert failure_callback in litellm.failure_callback
|
||||
|
||||
|
||||
def test_async_callbacks():
|
||||
manager = LoggingCallbackManager()
|
||||
|
||||
async_success = "async_success"
|
||||
async_failure = "async_failure"
|
||||
|
||||
# Add async callbacks
|
||||
manager.add_litellm_async_success_callback(async_success)
|
||||
manager.add_litellm_async_failure_callback(async_failure)
|
||||
|
||||
assert async_success in litellm._async_success_callback
|
||||
assert async_failure in litellm._async_failure_callback
|
||||
|
||||
|
||||
def test_reset_callbacks(callback_manager):
|
||||
# Add various callbacks
|
||||
callback_manager.add_litellm_callback("test")
|
||||
callback_manager.add_litellm_success_callback("success")
|
||||
callback_manager.add_litellm_failure_callback("failure")
|
||||
callback_manager.add_litellm_async_success_callback("async_success")
|
||||
callback_manager.add_litellm_async_failure_callback("async_failure")
|
||||
|
||||
# Reset all callbacks
|
||||
callback_manager._reset_all_callbacks()
|
||||
|
||||
# Verify all callback lists are empty
|
||||
assert len(litellm.callbacks) == 0
|
||||
assert len(litellm.success_callback) == 0
|
||||
assert len(litellm.failure_callback) == 0
|
||||
assert len(litellm._async_success_callback) == 0
|
||||
assert len(litellm._async_failure_callback) == 0
|
|
@ -846,6 +846,7 @@ async def test_async_embedding_openai():
|
|||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
|
@ -882,6 +883,7 @@ def test_amazing_sync_embedding():
|
|||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = litellm.embedding(
|
||||
|
@ -916,6 +918,7 @@ async def test_async_embedding_azure():
|
|||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
|
@ -956,6 +959,7 @@ async def test_async_embedding_bedrock():
|
|||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
|
@ -1123,6 +1127,7 @@ def test_image_generation_openai():
|
|||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = litellm.image_generation(
|
||||
|
|
|
@ -415,6 +415,8 @@ async def test_async_chat_azure():
|
|||
len(customHandler_completion_azure_router.states) == 3
|
||||
) # pre, post, success
|
||||
# streaming
|
||||
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_streaming_azure_router]
|
||||
router2 = Router(model_list=model_list, num_retries=0) # type: ignore
|
||||
response = await router2.acompletion(
|
||||
|
@ -445,6 +447,8 @@ async def test_async_chat_azure():
|
|||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
|
||||
try:
|
||||
|
@ -507,6 +511,7 @@ async def test_async_embedding_azure():
|
|||
"rpm": 1800,
|
||||
},
|
||||
]
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
|
||||
try:
|
||||
|
|
|
@ -261,6 +261,7 @@ def test_azure_completion_stream():
|
|||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_completion():
|
||||
try:
|
||||
litellm._turn_on_debug
|
||||
customHandler_success = MyCustomHandler()
|
||||
customHandler_failure = MyCustomHandler()
|
||||
# success
|
||||
|
@ -284,6 +285,7 @@ async def test_async_custom_handler_completion():
|
|||
== "gpt-3.5-turbo"
|
||||
)
|
||||
# failure
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
|
|
@ -13,6 +13,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from pydantic import BaseModel
|
||||
from litellm import utils, Router
|
||||
|
||||
|
@ -124,6 +125,7 @@ def test_rate_limit(
|
|||
ExpectNoException: Signfies that no other error has happened. A NOP
|
||||
"""
|
||||
# Can send more messages then we're going to; so don't expect a rate limit error
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
args = locals()
|
||||
print(f"args: {args}")
|
||||
expected_exception = (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue