Add recursion depth to convert_anyof_null_to_nullable, constants.py. Fix recursive_detector.py raise error state

This commit is contained in:
Nicholas Grabar 2025-03-28 13:11:19 -07:00
parent e7181395ff
commit 1f2bbda11d
7 changed files with 62 additions and 17 deletions

View file

@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512 DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2 DEFAULT_MAX_RETRIES = 2
DEFAULT_MAX_RECURSE_DEPTH = 10
DEFAULT_FAILURE_THRESHOLD_PERCENT = ( DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute 0.5 # default cooldown a deployment if 50% of requests fail in a given minute
) )

View file

@ -1,8 +1,9 @@
import json import json
from typing import Any, Union from typing import Any, Union
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
def safe_dumps(data: Any, max_depth: int = 10) -> str: def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
""" """
Recursively serialize data while detecting circular references. Recursively serialize data while detecting circular references.
If a circular reference is detected then a marker string is returned. If a circular reference is detected then a marker string is returned.

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
class SensitiveDataMasker: class SensitiveDataMasker:
@ -39,7 +40,7 @@ class SensitiveDataMasker:
return result return result
def mask_dict( def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10 self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if depth >= max_depth: if depth >= max_depth:
return data return data

View file

@ -5,6 +5,7 @@ import httpx
import litellm import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType from litellm.types.llms.vertex_ai import PartType
@ -228,7 +229,11 @@ def unpack_defs(schema, defs):
continue continue
def convert_anyof_null_to_nullable(schema): def convert_anyof_null_to_nullable(schema, depth=0):
if depth > DEFAULT_MAX_RECURSE_DEPTH:
raise ValueError(
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
)
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """ """ Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
anyof = schema.get("anyOf", None) anyof = schema.get("anyOf", None)
if anyof is not None: if anyof is not None:
@ -256,11 +261,11 @@ def convert_anyof_null_to_nullable(schema):
properties = schema.get("properties", None) properties = schema.get("properties", None)
if properties is not None: if properties is not None:
for name, value in properties.items(): for name, value in properties.items():
convert_anyof_null_to_nullable(value) convert_anyof_null_to_nullable(value, depth=depth + 1)
items = schema.get("items", None) items = schema.get("items", None)
if items is not None: if items is not None:
convert_anyof_null_to_nullable(items) convert_anyof_null_to_nullable(items, depth=depth + 1)
def add_object_type(schema): def add_object_type(schema):

View file

@ -29,6 +29,7 @@ from litellm.types.utils import (
ModelResponseStream, ModelResponseStream,
TextCompletionResponse, TextCompletionResponse,
) )
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span from opentelemetry.trace import Span as _Span
@ -1522,7 +1523,7 @@ class ProxyConfig:
yaml.dump(new_config, config_file, default_flow_style=False) yaml.dump(new_config, config_file, default_flow_style=False)
def _check_for_os_environ_vars( def _check_for_os_environ_vars(
self, config: dict, depth: int = 0, max_depth: int = 10 self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> dict: ) -> dict:
""" """
Check for os.environ/ variables in the config and replace them with the actual values. Check for os.environ/ variables in the config and replace them with the actual values.

View file

@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [
"_check_for_os_environ_vars", "_check_for_os_environ_vars",
"clean_message", "clean_message",
"unpack_defs", "unpack_defs",
"convert_to_nullable", "convert_anyof_null_to_nullable", # has a set max depth
"add_object_type", "add_object_type",
"strip_field", "strip_field",
"_transform_prompt", "_transform_prompt",
@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory):
ignored_recursive_functions[file_path] = ignored ignored_recursive_functions[file_path] = ignored
return recursive_functions, ignored_recursive_functions return recursive_functions, ignored_recursive_functions
if __name__ == "__main__":
# Example usage
# raise exception if any recursive functions are found, except for the ignored ones
# this is used in the CI/CD pipeline to prevent recursive functions from being merged
# Example usage directory_path = "./litellm"
directory_path = "./litellm" recursive_functions, ignored_recursive_functions = (
recursive_functions, ignored_recursive_functions = ( find_recursive_functions_in_directory(directory_path)
find_recursive_functions_in_directory(directory_path)
)
print("ALL RECURSIVE FUNCTIONS: ", recursive_functions)
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
if len(recursive_functions) > 0:
raise Exception(
f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
) )
print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
if len(recursive_functions) > 0:
# raise exception if any recursive functions are found
for file, functions in recursive_functions.items():
print(
f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
)
file, functions = list(recursive_functions.items())[0]
raise Exception(
f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
)

View file

@ -1,6 +1,7 @@
import os import os
import sys import sys
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
import pytest import pytest
@ -112,6 +113,30 @@ def test_nested_anyof_conversion():
} }
assert schema == expected assert schema == expected
def test_anyof_with_excessive_nesting():
"""Test conversion with excessive nesting > max levels +1 deep."""
# generate a schema with excessive nesting
schema = {"type": "object", "properties": {}}
current = schema
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
current["properties"] = {
"nested": {
"anyOf": [
{"type": "string"},
{"type": "null"}
],
"properties": {}
}
}
current = current["properties"]["nested"]
# running the conversion will raise an error
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
convert_anyof_null_to_nullable(schema)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_supports_system_message(): async def test_get_supports_system_message():
"""Test get_supports_system_message with different models""" """Test get_supports_system_message with different models"""