(Refactor / QA) - Use LoggingCallbackManager to append callbacks and ensure no duplicate callbacks are added (#8112)

* LoggingCallbackManager

* add logging_callback_manager

* use logging_callback_manager

* add add_litellm_failure_callback

* use add_litellm_callback

* use add_litellm_async_success_callback

* add_litellm_async_failure_callback

* linting fix

* fix logging callback manager

* test_duplicate_multiple_loggers_test

* use _reset_all_callbacks

* fix testing with dup callbacks

* test_basic_image_generation

* reset callbacks for tests

* fix check for _add_custom_logger_to_list

* fix test_amazing_sync_embedding

* fix _get_custom_logger_key

* fix batches testing

* fix _reset_all_callbacks

* fix _check_callback_list_size

* add callback_manager_test

* fix test gemini-2.0-flash-thinking-exp-01-21
This commit is contained in:
Ishaan Jaff 2025-01-30 19:35:50 -08:00 committed by GitHub
parent 3eac1634fa
commit 8a235e7d38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 607 additions and 59 deletions

View file

@ -0,0 +1,113 @@
import ast
import os
ALLOWED_FILES = [
# local files
"../../litellm/litellm_core_utils/litellm.logging_callback_manager.py",
"../../litellm/proxy/common_utils/callback_utils.py",
# when running on ci/cd
"./litellm/litellm_core_utils/litellm.logging_callback_manager.py",
"./litellm/proxy/common_utils/callback_utils.py",
]
warning_msg = "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager"
def check_for_callback_modifications(file_path):
"""
Checks if any direct modifications to specific litellm callback lists are made in the given file.
Also prints the violating line of code.
"""
print("..checking file=", file_path)
if file_path in ALLOWED_FILES:
return []
violations = []
with open(file_path, "r") as file:
try:
lines = file.readlines()
tree = ast.parse("".join(lines))
except SyntaxError:
print(f"Warning: Syntax error in file {file_path}")
return violations
protected_lists = [
"callbacks",
"success_callback",
"failure_callback",
"_async_success_callback",
"_async_failure_callback",
]
forbidden_operations = ["append", "extend", "insert"]
for node in ast.walk(tree):
# Check for attribute calls like litellm.callbacks.append()
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
# Get the full attribute chain
attr_chain = []
current = node.func
while isinstance(current, ast.Attribute):
attr_chain.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
attr_chain.append(current.id)
# Reverse to get the chain from root to leaf
attr_chain = attr_chain[::-1]
# Check if the attribute chain starts with 'litellm' and modifies a protected list
if (
len(attr_chain) >= 3
and attr_chain[0] == "litellm"
and attr_chain[2] in forbidden_operations
):
protected_list = attr_chain[1]
operation = attr_chain[2]
if (
protected_list in protected_lists
and operation in forbidden_operations
):
violating_line = lines[node.lineno - 1].strip()
violations.append(
f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. "
f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. "
f"Please use LoggingCallbackManager instead. {warning_msg}"
)
return violations
def scan_directory_for_callback_modifications(base_dir):
"""
Scans all Python files in the directory tree for unauthorized callback list modifications.
"""
all_violations = []
for root, _, files in os.walk(base_dir):
for file in files:
if file.endswith(".py"):
file_path = os.path.join(root, file)
violations = check_for_callback_modifications(file_path)
all_violations.extend(violations)
return all_violations
def test_no_unauthorized_callback_modifications():
"""
Test to ensure callback lists are not modified directly anywhere in the codebase.
"""
base_dir = "./litellm" # Adjust this path as needed
# base_dir = "../../litellm" # LOCAL TESTING
violations = scan_directory_for_callback_modifications(base_dir)
if violations:
print(f"\nFound {len(violations)} callback modification violations:")
for violation in violations:
print("\n" + violation)
raise AssertionError(
"Found unauthorized callback modifications. See above for details."
)
if __name__ == "__main__":
test_no_unauthorized_callback_modifications()