diff --git a/litellm/constants.py b/litellm/constants.py index da66f897c9..d45994f400 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5 DEFAULT_BATCH_SIZE = 512 DEFAULT_FLUSH_INTERVAL_SECONDS = 5 DEFAULT_MAX_RETRIES = 2 +DEFAULT_MAX_RECURSE_DEPTH = 10 DEFAULT_FAILURE_THRESHOLD_PERCENT = ( 0.5 # default cooldown a deployment if 50% of requests fail in a given minute ) diff --git a/litellm/litellm_core_utils/safe_json_dumps.py b/litellm/litellm_core_utils/safe_json_dumps.py index 990c0ed561..7ad0038ecb 100644 --- a/litellm/litellm_core_utils/safe_json_dumps.py +++ b/litellm/litellm_core_utils/safe_json_dumps.py @@ -1,8 +1,9 @@ import json from typing import Any, Union +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH -def safe_dumps(data: Any, max_depth: int = 10) -> str: +def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str: """ Recursively serialize data while detecting circular references. If a circular reference is detected then a marker string is returned. diff --git a/litellm/litellm_core_utils/sensitive_data_masker.py b/litellm/litellm_core_utils/sensitive_data_masker.py index a1df115ff0..7800e5304f 100644 --- a/litellm/litellm_core_utils/sensitive_data_masker.py +++ b/litellm/litellm_core_utils/sensitive_data_masker.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Optional, Set +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH class SensitiveDataMasker: @@ -39,7 +40,7 @@ class SensitiveDataMasker: return result def mask_dict( - self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10 + self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH ) -> Dict[str, Any]: if depth >= max_depth: return data diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index db2960617c..337445777a 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -5,6 +5,7 @@ import httpx import litellm from litellm import supports_response_schema, supports_system_messages, verbose_logger +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.types.llms.vertex_ai import PartType @@ -228,7 +229,11 @@ def unpack_defs(schema, defs): continue -def convert_anyof_null_to_nullable(schema): +def convert_anyof_null_to_nullable(schema, depth=0): + if depth > DEFAULT_MAX_RECURSE_DEPTH: + raise ValueError( + f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting." + ) """ Converts null objects within anyOf by removing them and adding nullable to all remaining objects """ anyof = schema.get("anyOf", None) if anyof is not None: @@ -256,11 +261,11 @@ def convert_anyof_null_to_nullable(schema): properties = schema.get("properties", None) if properties is not None: for name, value in properties.items(): - convert_anyof_null_to_nullable(value) + convert_anyof_null_to_nullable(value, depth=depth + 1) items = schema.get("items", None) if items is not None: - convert_anyof_null_to_nullable(items) + convert_anyof_null_to_nullable(items, depth=depth + 1) def add_object_type(schema): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index df295c3697..16f28e5bdc 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -29,6 +29,7 @@ from litellm.types.utils import ( ModelResponseStream, TextCompletionResponse, ) +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -1522,7 +1523,7 @@ class ProxyConfig: yaml.dump(new_config, config_file, default_flow_style=False) def _check_for_os_environ_vars( - self, config: dict, depth: int = 0, max_depth: int = 10 + self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH ) -> dict: """ Check for os.environ/ variables in the config and replace them with the actual values. diff --git a/tests/code_coverage_tests/recursive_detector.py b/tests/code_coverage_tests/recursive_detector.py index 3e9c597b94..b748d1a517 100644 --- a/tests/code_coverage_tests/recursive_detector.py +++ b/tests/code_coverage_tests/recursive_detector.py @@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [ "_check_for_os_environ_vars", "clean_message", "unpack_defs", - "convert_to_nullable", + "convert_anyof_null_to_nullable", # has a set max depth "add_object_type", "strip_field", "_transform_prompt", @@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory): ignored_recursive_functions[file_path] = ignored return recursive_functions, ignored_recursive_functions +if __name__ == "__main__": + # Example usage + # raise exception if any recursive functions are found, except for the ignored ones + # this is used in the CI/CD pipeline to prevent recursive functions from being merged -# Example usage -directory_path = "./litellm" -recursive_functions, ignored_recursive_functions = ( - find_recursive_functions_in_directory(directory_path) -) -print("ALL RECURSIVE FUNCTIONS: ", recursive_functions) -print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions) -if len(recursive_functions) > 0: - raise Exception( - f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." + directory_path = "./litellm" + recursive_functions, ignored_recursive_functions = ( + find_recursive_functions_in_directory(directory_path) ) + print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions) + print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions) + + if len(recursive_functions) > 0: + # raise exception if any recursive functions are found + for file, functions in recursive_functions.items(): + print( + f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." + ) + file, functions = list(recursive_functions.items())[0] + raise Exception( + f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." + ) + diff --git a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py index 0dc4a533fe..9412cc6d81 100644 --- a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py +++ b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py @@ -1,6 +1,7 @@ import os import sys from unittest.mock import MagicMock, call, patch +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH import pytest @@ -112,6 +113,30 @@ def test_nested_anyof_conversion(): } assert schema == expected +def test_anyof_with_excessive_nesting(): + """Test conversion with excessive nesting > max levels +1 deep.""" + # generate a schema with excessive nesting + schema = {"type": "object", "properties": {}} + current = schema + for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1): + current["properties"] = { + "nested": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "properties": {} + } + } + current = current["properties"]["nested"] + + + # running the conversion will raise an error + with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."): + convert_anyof_null_to_nullable(schema) + + + @pytest.mark.asyncio async def test_get_supports_system_message(): """Test get_supports_system_message with different models"""