Add recursion depth to convert_anyof_null_to_nullable, constants.py. Fix recursive_detector.py raise error state

This commit is contained in:
Nicholas Grabar 2025-03-28 13:11:19 -07:00
parent e7181395ff
commit 1f2bbda11d
7 changed files with 62 additions and 17 deletions

View file

@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
DEFAULT_MAX_RECURSE_DEPTH = 10
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
)

View file

@ -1,8 +1,9 @@
import json
from typing import Any, Union
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
def safe_dumps(data: Any, max_depth: int = 10) -> str:
def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
"""
Recursively serialize data while detecting circular references.
If a circular reference is detected then a marker string is returned.

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
class SensitiveDataMasker:
@ -39,7 +40,7 @@ class SensitiveDataMasker:
return result
def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> Dict[str, Any]:
if depth >= max_depth:
return data

View file

@ -5,6 +5,7 @@ import httpx
import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
@ -228,7 +229,11 @@ def unpack_defs(schema, defs):
continue
def convert_anyof_null_to_nullable(schema):
def convert_anyof_null_to_nullable(schema, depth=0):
if depth > DEFAULT_MAX_RECURSE_DEPTH:
raise ValueError(
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
)
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
anyof = schema.get("anyOf", None)
if anyof is not None:
@ -256,11 +261,11 @@ def convert_anyof_null_to_nullable(schema):
properties = schema.get("properties", None)
if properties is not None:
for name, value in properties.items():
convert_anyof_null_to_nullable(value)
convert_anyof_null_to_nullable(value, depth=depth + 1)
items = schema.get("items", None)
if items is not None:
convert_anyof_null_to_nullable(items)
convert_anyof_null_to_nullable(items, depth=depth + 1)
def add_object_type(schema):

View file

@ -29,6 +29,7 @@ from litellm.types.utils import (
ModelResponseStream,
TextCompletionResponse,
)
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@ -1522,7 +1523,7 @@ class ProxyConfig:
yaml.dump(new_config, config_file, default_flow_style=False)
def _check_for_os_environ_vars(
self, config: dict, depth: int = 0, max_depth: int = 10
self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> dict:
"""
Check for os.environ/ variables in the config and replace them with the actual values.

View file

@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [
"_check_for_os_environ_vars",
"clean_message",
"unpack_defs",
"convert_to_nullable",
"convert_anyof_null_to_nullable", # has a set max depth
"add_object_type",
"strip_field",
"_transform_prompt",
@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory):
ignored_recursive_functions[file_path] = ignored
return recursive_functions, ignored_recursive_functions
if __name__ == "__main__":
# Example usage
# raise exception if any recursive functions are found, except for the ignored ones
# this is used in the CI/CD pipeline to prevent recursive functions from being merged
# Example usage
directory_path = "./litellm"
recursive_functions, ignored_recursive_functions = (
directory_path = "./litellm"
recursive_functions, ignored_recursive_functions = (
find_recursive_functions_in_directory(directory_path)
)
print("ALL RECURSIVE FUNCTIONS: ", recursive_functions)
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
if len(recursive_functions) > 0:
raise Exception(
f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
)
print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
if len(recursive_functions) > 0:
# raise exception if any recursive functions are found
for file, functions in recursive_functions.items():
print(
f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
)
file, functions = list(recursive_functions.items())[0]
raise Exception(
f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
)

View file

@ -1,6 +1,7 @@
import os
import sys
from unittest.mock import MagicMock, call, patch
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
import pytest
@ -112,6 +113,30 @@ def test_nested_anyof_conversion():
}
assert schema == expected
def test_anyof_with_excessive_nesting():
"""Test conversion with excessive nesting > max levels +1 deep."""
# generate a schema with excessive nesting
schema = {"type": "object", "properties": {}}
current = schema
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
current["properties"] = {
"nested": {
"anyOf": [
{"type": "string"},
{"type": "null"}
],
"properties": {}
}
}
current = current["properties"]["nested"]
# running the conversion will raise an error
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
convert_anyof_null_to_nullable(schema)
@pytest.mark.asyncio
async def test_get_supports_system_message():
"""Test get_supports_system_message with different models"""