mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
* LoggingCallbackManager * add logging_callback_manager * use logging_callback_manager * add add_litellm_failure_callback * use add_litellm_callback * use add_litellm_async_success_callback * add_litellm_async_failure_callback * linting fix * fix logging callback manager * test_duplicate_multiple_loggers_test * use _reset_all_callbacks * fix testing with dup callbacks * test_basic_image_generation * reset callbacks for tests * fix check for _add_custom_logger_to_list * fix test_amazing_sync_embedding * fix _get_custom_logger_key * fix batches testing * fix _reset_all_callbacks * fix _check_callback_list_size * add callback_manager_test * fix test gemini-2.0-flash-thinking-exp-01-21
113 lines
4 KiB
Python
113 lines
4 KiB
Python
import ast
|
|
import os
|
|
|
|
# Files for which check_for_callback_modifications() returns no violations.
# Every file is listed once per scan root, because the scan root differs
# between a local run and a CI/CD run.
# NOTE(review): the "litellm." prefix in the first file name looks unusual;
# confirm it matches the real file name on disk.
_ALLOWED_RELATIVE_PATHS = (
    "litellm_core_utils/litellm.logging_callback_manager.py",
    "proxy/common_utils/callback_utils.py",
)
_SCAN_ROOTS = (
    "../../litellm",  # local files
    "./litellm",  # when running on ci/cd
)
ALLOWED_FILES = [
    f"{scan_root}/{relative_path}"
    for scan_root in _SCAN_ROOTS
    for relative_path in _ALLOWED_RELATIVE_PATHS
]

# Suffix appended to every violation message.
warning_msg = "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager"
|
|
|
|
|
|
def check_for_callback_modifications(file_path):
    """
    Report direct modifications of the protected litellm callback lists in one file.

    The file is parsed with ``ast`` and every call of the form
    ``litellm.<protected_list>.<append|extend|insert>(...)`` is reported.
    Each message contains the file path, the line number and the stripped
    source line of the violating call.

    Args:
        file_path: Path of the Python source file to inspect.

    Returns:
        A list of violation messages. The list is empty when the file is
        listed in ``ALLOWED_FILES``, when it cannot be parsed (a warning is
        printed in that case), or when it contains no violations.
    """
    print("..checking file=", file_path)

    # Compare normalized paths so that "./a/b.py", "a/b.py" and OS-specific
    # separators all match the same allow-list entry.
    normalized_path = os.path.normpath(file_path)
    if any(normalized_path == os.path.normpath(allowed) for allowed in ALLOWED_FILES):
        return []

    # Read with an explicit encoding instead of the locale default;
    # "utf-8-sig" also drops a leading BOM if one is present.
    with open(file_path, "r", encoding="utf-8-sig") as file:
        lines = file.readlines()

    try:
        tree = ast.parse("".join(lines))
    except SyntaxError:
        print(f"Warning: Syntax error in file {file_path}")
        return []

    protected_lists = (
        "callbacks",
        "success_callback",
        "failure_callback",
        "_async_success_callback",
        "_async_failure_callback",
    )
    forbidden_operations = ("append", "extend", "insert")

    violations = []
    for node in ast.walk(tree):
        # Only attribute calls such as litellm.callbacks.append(...) are relevant.
        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
            continue

        # Collect the attribute chain from leaf to root.
        attr_chain = []
        current = node.func
        while isinstance(current, ast.Attribute):
            attr_chain.append(current.attr)
            current = current.value

        # The chain must be rooted in a plain name. Otherwise an expression
        # such as `get().litellm.callbacks.append(x)` would be mistaken for
        # `litellm.callbacks.append(x)`.
        if not isinstance(current, ast.Name):
            continue
        attr_chain.append(current.id)

        # Reverse to get the chain from root to leaf.
        attr_chain.reverse()

        if len(attr_chain) < 3 or attr_chain[0] != "litellm":
            continue

        protected_list, operation = attr_chain[1], attr_chain[2]
        if protected_list in protected_lists and operation in forbidden_operations:
            violating_line = lines[node.lineno - 1].strip()
            violations.append(
                f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. "
                f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. "
                f"Please use LoggingCallbackManager instead. {warning_msg}"
            )

    return violations
|
|
|
|
|
|
def scan_directory_for_callback_modifications(base_dir):
    """
    Walk ``base_dir`` recursively and check every ``.py`` file for
    unauthorized callback list modifications.

    Returns a flat list with the violation messages of all checked files,
    in the order the files are visited by ``os.walk``.
    """
    collected = []
    for dirpath, _dirnames, filenames in os.walk(base_dir):
        python_paths = (
            os.path.join(dirpath, filename)
            for filename in filenames
            if filename.endswith(".py")
        )
        for python_path in python_paths:
            collected.extend(check_for_callback_modifications(python_path))
    return collected
|
|
|
|
|
|
def test_no_unauthorized_callback_modifications():
    """
    Fail when any file under the scanned directory modifies a callback list directly.

    All violations are printed first; afterwards a single AssertionError is
    raised. Nothing is raised when the scan reports no violations.
    """
    base_dir = "./litellm"  # Adjust this path as needed
    # For LOCAL TESTING use "../../litellm" as base_dir instead.

    found = scan_directory_for_callback_modifications(base_dir)
    if not found:
        return

    print(f"\nFound {len(found)} callback modification violations:")
    for message in found:
        print("\n" + message)
    raise AssertionError(
        "Found unauthorized callback modifications. See above for details."
    )
|
|
|
|
|
|
# Allow running the check directly as a script, without a test runner.
if __name__ == "__main__":
    test_no_unauthorized_callback_modifications()
|