mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Add recursion depth to convert_anyof_null_to_nullable, constants.py. Fix recursive_detector.py raise error state
This commit is contained in:
parent
e7181395ff
commit
1f2bbda11d
7 changed files with 62 additions and 17 deletions
|
@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5
|
||||||
DEFAULT_BATCH_SIZE = 512
|
DEFAULT_BATCH_SIZE = 512
|
||||||
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
||||||
DEFAULT_MAX_RETRIES = 2
|
DEFAULT_MAX_RETRIES = 2
|
||||||
|
DEFAULT_MAX_RECURSE_DEPTH = 10
|
||||||
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
||||||
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import json
|
import json
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
|
|
||||||
def safe_dumps(data: Any, max_depth: int = 10) -> str:
|
def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||||
"""
|
"""
|
||||||
Recursively serialize data while detecting circular references.
|
Recursively serialize data while detecting circular references.
|
||||||
If a circular reference is detected then a marker string is returned.
|
If a circular reference is detected then a marker string is returned.
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Dict, Optional, Set
|
from typing import Any, Dict, Optional, Set
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
|
|
||||||
class SensitiveDataMasker:
|
class SensitiveDataMasker:
|
||||||
|
@ -39,7 +40,7 @@ class SensitiveDataMasker:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def mask_dict(
|
def mask_dict(
|
||||||
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
|
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if depth >= max_depth:
|
if depth >= max_depth:
|
||||||
return data
|
return data
|
||||||
|
|
|
@ -5,6 +5,7 @@ import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
from litellm.types.llms.vertex_ai import PartType
|
from litellm.types.llms.vertex_ai import PartType
|
||||||
|
|
||||||
|
@ -228,7 +229,11 @@ def unpack_defs(schema, defs):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
def convert_anyof_null_to_nullable(schema):
|
def convert_anyof_null_to_nullable(schema, depth=0):
|
||||||
|
if depth > DEFAULT_MAX_RECURSE_DEPTH:
|
||||||
|
raise ValueError(
|
||||||
|
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
|
||||||
|
)
|
||||||
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
|
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
|
||||||
anyof = schema.get("anyOf", None)
|
anyof = schema.get("anyOf", None)
|
||||||
if anyof is not None:
|
if anyof is not None:
|
||||||
|
@ -256,11 +261,11 @@ def convert_anyof_null_to_nullable(schema):
|
||||||
properties = schema.get("properties", None)
|
properties = schema.get("properties", None)
|
||||||
if properties is not None:
|
if properties is not None:
|
||||||
for name, value in properties.items():
|
for name, value in properties.items():
|
||||||
convert_anyof_null_to_nullable(value)
|
convert_anyof_null_to_nullable(value, depth=depth + 1)
|
||||||
|
|
||||||
items = schema.get("items", None)
|
items = schema.get("items", None)
|
||||||
if items is not None:
|
if items is not None:
|
||||||
convert_anyof_null_to_nullable(items)
|
convert_anyof_null_to_nullable(items, depth=depth + 1)
|
||||||
|
|
||||||
|
|
||||||
def add_object_type(schema):
|
def add_object_type(schema):
|
||||||
|
|
|
@ -29,6 +29,7 @@ from litellm.types.utils import (
|
||||||
ModelResponseStream,
|
ModelResponseStream,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
)
|
)
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -1522,7 +1523,7 @@ class ProxyConfig:
|
||||||
yaml.dump(new_config, config_file, default_flow_style=False)
|
yaml.dump(new_config, config_file, default_flow_style=False)
|
||||||
|
|
||||||
def _check_for_os_environ_vars(
|
def _check_for_os_environ_vars(
|
||||||
self, config: dict, depth: int = 0, max_depth: int = 10
|
self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Check for os.environ/ variables in the config and replace them with the actual values.
|
Check for os.environ/ variables in the config and replace them with the actual values.
|
||||||
|
|
|
@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [
|
||||||
"_check_for_os_environ_vars",
|
"_check_for_os_environ_vars",
|
||||||
"clean_message",
|
"clean_message",
|
||||||
"unpack_defs",
|
"unpack_defs",
|
||||||
"convert_to_nullable",
|
"convert_anyof_null_to_nullable", # has a set max depth
|
||||||
"add_object_type",
|
"add_object_type",
|
||||||
"strip_field",
|
"strip_field",
|
||||||
"_transform_prompt",
|
"_transform_prompt",
|
||||||
|
@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory):
|
||||||
ignored_recursive_functions[file_path] = ignored
|
ignored_recursive_functions[file_path] = ignored
|
||||||
return recursive_functions, ignored_recursive_functions
|
return recursive_functions, ignored_recursive_functions
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example usage
|
||||||
|
# raise exception if any recursive functions are found, except for the ignored ones
|
||||||
|
# this is used in the CI/CD pipeline to prevent recursive functions from being merged
|
||||||
|
|
||||||
# Example usage
|
directory_path = "./litellm"
|
||||||
directory_path = "./litellm"
|
recursive_functions, ignored_recursive_functions = (
|
||||||
recursive_functions, ignored_recursive_functions = (
|
find_recursive_functions_in_directory(directory_path)
|
||||||
find_recursive_functions_in_directory(directory_path)
|
|
||||||
)
|
|
||||||
print("ALL RECURSIVE FUNCTIONS: ", recursive_functions)
|
|
||||||
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
|
|
||||||
if len(recursive_functions) > 0:
|
|
||||||
raise Exception(
|
|
||||||
f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
|
||||||
)
|
)
|
||||||
|
print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
|
||||||
|
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
|
||||||
|
|
||||||
|
if len(recursive_functions) > 0:
|
||||||
|
# raise exception if any recursive functions are found
|
||||||
|
for file, functions in recursive_functions.items():
|
||||||
|
print(
|
||||||
|
f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||||
|
)
|
||||||
|
file, functions = list(recursive_functions.items())[0]
|
||||||
|
raise Exception(
|
||||||
|
f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -112,6 +113,30 @@ def test_nested_anyof_conversion():
|
||||||
}
|
}
|
||||||
assert schema == expected
|
assert schema == expected
|
||||||
|
|
||||||
|
def test_anyof_with_excessive_nesting():
|
||||||
|
"""Test conversion with excessive nesting > max levels +1 deep."""
|
||||||
|
# generate a schema with excessive nesting
|
||||||
|
schema = {"type": "object", "properties": {}}
|
||||||
|
current = schema
|
||||||
|
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
|
||||||
|
current["properties"] = {
|
||||||
|
"nested": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
],
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = current["properties"]["nested"]
|
||||||
|
|
||||||
|
|
||||||
|
# running the conversion will raise an error
|
||||||
|
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_supports_system_message():
|
async def test_get_supports_system_message():
|
||||||
"""Test get_supports_system_message with different models"""
|
"""Test get_supports_system_message with different models"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue