mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Add recursion depth to convert_anyof_null_to_nullable, constants.py. Fix recursive_detector.py raise error state
This commit is contained in:
parent
e7181395ff
commit
1f2bbda11d
7 changed files with 62 additions and 17 deletions
|
@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5
|
|||
DEFAULT_BATCH_SIZE = 512
|
||||
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
||||
DEFAULT_MAX_RETRIES = 2
|
||||
DEFAULT_MAX_RECURSE_DEPTH = 10
|
||||
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
||||
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
||||
)
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import json
|
||||
from typing import Any, Union
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
|
||||
|
||||
def safe_dumps(data: Any, max_depth: int = 10) -> str:
|
||||
def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||
"""
|
||||
Recursively serialize data while detecting circular references.
|
||||
If a circular reference is detected then a marker string is returned.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Any, Dict, Optional, Set
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
|
||||
|
||||
class SensitiveDataMasker:
|
||||
|
@ -39,7 +40,7 @@ class SensitiveDataMasker:
|
|||
return result
|
||||
|
||||
def mask_dict(
|
||||
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
|
||||
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
|
||||
) -> Dict[str, Any]:
|
||||
if depth >= max_depth:
|
||||
return data
|
||||
|
|
|
@ -5,6 +5,7 @@ import httpx
|
|||
|
||||
import litellm
|
||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.vertex_ai import PartType
|
||||
|
||||
|
@ -228,7 +229,11 @@ def unpack_defs(schema, defs):
|
|||
continue
|
||||
|
||||
|
||||
def convert_anyof_null_to_nullable(schema):
|
||||
def convert_anyof_null_to_nullable(schema, depth=0):
|
||||
if depth > DEFAULT_MAX_RECURSE_DEPTH:
|
||||
raise ValueError(
|
||||
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
|
||||
)
|
||||
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
|
||||
anyof = schema.get("anyOf", None)
|
||||
if anyof is not None:
|
||||
|
@ -256,11 +261,11 @@ def convert_anyof_null_to_nullable(schema):
|
|||
properties = schema.get("properties", None)
|
||||
if properties is not None:
|
||||
for name, value in properties.items():
|
||||
convert_anyof_null_to_nullable(value)
|
||||
convert_anyof_null_to_nullable(value, depth=depth + 1)
|
||||
|
||||
items = schema.get("items", None)
|
||||
if items is not None:
|
||||
convert_anyof_null_to_nullable(items)
|
||||
convert_anyof_null_to_nullable(items, depth=depth + 1)
|
||||
|
||||
|
||||
def add_object_type(schema):
|
||||
|
|
|
@ -29,6 +29,7 @@ from litellm.types.utils import (
|
|||
ModelResponseStream,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -1522,7 +1523,7 @@ class ProxyConfig:
|
|||
yaml.dump(new_config, config_file, default_flow_style=False)
|
||||
|
||||
def _check_for_os_environ_vars(
|
||||
self, config: dict, depth: int = 0, max_depth: int = 10
|
||||
self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
|
||||
) -> dict:
|
||||
"""
|
||||
Check for os.environ/ variables in the config and replace them with the actual values.
|
||||
|
|
|
@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [
|
|||
"_check_for_os_environ_vars",
|
||||
"clean_message",
|
||||
"unpack_defs",
|
||||
"convert_to_nullable",
|
||||
"convert_anyof_null_to_nullable", # has a set max depth
|
||||
"add_object_type",
|
||||
"strip_field",
|
||||
"_transform_prompt",
|
||||
|
@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory):
|
|||
ignored_recursive_functions[file_path] = ignored
|
||||
return recursive_functions, ignored_recursive_functions
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
# raise exception if any recursive functions are found, except for the ignored ones
|
||||
# this is used in the CI/CD pipeline to prevent recursive functions from being merged
|
||||
|
||||
# Example usage
|
||||
directory_path = "./litellm"
|
||||
recursive_functions, ignored_recursive_functions = (
|
||||
directory_path = "./litellm"
|
||||
recursive_functions, ignored_recursive_functions = (
|
||||
find_recursive_functions_in_directory(directory_path)
|
||||
)
|
||||
print("ALL RECURSIVE FUNCTIONS: ", recursive_functions)
|
||||
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
|
||||
if len(recursive_functions) > 0:
|
||||
raise Exception(
|
||||
f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||
)
|
||||
print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
|
||||
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
|
||||
|
||||
if len(recursive_functions) > 0:
|
||||
# raise exception if any recursive functions are found
|
||||
for file, functions in recursive_functions.items():
|
||||
print(
|
||||
f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||
)
|
||||
file, functions = list(recursive_functions.items())[0]
|
||||
raise Exception(
|
||||
f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||
)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -112,6 +113,30 @@ def test_nested_anyof_conversion():
|
|||
}
|
||||
assert schema == expected
|
||||
|
||||
def test_anyof_with_excessive_nesting():
|
||||
"""Test conversion with excessive nesting > max levels +1 deep."""
|
||||
# generate a schema with excessive nesting
|
||||
schema = {"type": "object", "properties": {}}
|
||||
current = schema
|
||||
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
|
||||
current["properties"] = {
|
||||
"nested": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "null"}
|
||||
],
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
current = current["properties"]["nested"]
|
||||
|
||||
|
||||
# running the conversion will raise an error
|
||||
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
|
||||
convert_anyof_null_to_nullable(schema)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_supports_system_message():
|
||||
"""Test get_supports_system_message with different models"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue