diff --git a/.gitignore b/.gitignore index 1cdedb83fc..aac1a61ddc 100644 --- a/.gitignore +++ b/.gitignore @@ -83,4 +83,4 @@ tests/llm_translation/test_vertex_key.json litellm/proxy/migrations/0_init/migration.sql litellm/proxy/db/migrations/0_init/migration.sql litellm/proxy/db/migrations/* -litellm/proxy/migrations/* \ No newline at end of file +litellm/proxy/migrations/*config.yaml diff --git a/litellm/constants.py b/litellm/constants.py index de0a7e366d..b355a0f683 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5 DEFAULT_BATCH_SIZE = 512 DEFAULT_FLUSH_INTERVAL_SECONDS = 5 DEFAULT_MAX_RETRIES = 2 +DEFAULT_MAX_RECURSE_DEPTH = 10 DEFAULT_FAILURE_THRESHOLD_PERCENT = ( 0.5 # default cooldown a deployment if 50% of requests fail in a given minute ) diff --git a/litellm/litellm_core_utils/safe_json_dumps.py b/litellm/litellm_core_utils/safe_json_dumps.py index 990c0ed561..7ad0038ecb 100644 --- a/litellm/litellm_core_utils/safe_json_dumps.py +++ b/litellm/litellm_core_utils/safe_json_dumps.py @@ -1,8 +1,9 @@ import json from typing import Any, Union +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH -def safe_dumps(data: Any, max_depth: int = 10) -> str: +def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str: """ Recursively serialize data while detecting circular references. If a circular reference is detected then a marker string is returned. diff --git a/litellm/litellm_core_utils/sensitive_data_masker.py b/litellm/litellm_core_utils/sensitive_data_masker.py index a1df115ff0..7800e5304f 100644 --- a/litellm/litellm_core_utils/sensitive_data_masker.py +++ b/litellm/litellm_core_utils/sensitive_data_masker.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Optional, Set +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH class SensitiveDataMasker: @@ -39,7 +40,7 @@ class SensitiveDataMasker: return result def mask_dict( - self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10 + self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH ) -> Dict[str, Any]: if depth >= max_depth: return data diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index 9c9db2f047..337445777a 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -5,6 +5,7 @@ import httpx import litellm from litellm import supports_response_schema, supports_system_messages, verbose_logger +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.types.llms.vertex_ai import PartType @@ -177,7 +178,7 @@ def _build_vertex_schema(parameters: dict): # * https://github.com/pydantic/pydantic/issues/1270 # * https://stackoverflow.com/a/58841311 # * https://github.com/pydantic/pydantic/discussions/4872 - convert_to_nullable(parameters) + convert_anyof_null_to_nullable(parameters) add_object_type(parameters) # Postprocessing # 4. Suppress unnecessary title generation: @@ -228,34 +229,43 @@ def unpack_defs(schema, defs): continue -def convert_to_nullable(schema): - anyof = schema.pop("anyOf", None) +def convert_anyof_null_to_nullable(schema, depth=0): + if depth > DEFAULT_MAX_RECURSE_DEPTH: + raise ValueError( + f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting." + ) + """ Converts null objects within anyOf by removing them and adding nullable to all remaining objects """ + anyof = schema.get("anyOf", None) if anyof is not None: - if len(anyof) != 2: + contains_null = False + for atype in anyof: + if atype == {"type": "null"}: + # remove null type + anyof.remove(atype) + contains_null = True + + if len(anyof) == 0: + # Edge case: response schema with only null type present is invalid in Vertex AI raise ValueError( - "Invalid input: Type Unions are not supported, except for `Optional` types. " - "Please provide an `Optional` type or a non-Union type." + "Invalid input: AnyOf schema with only null type is not supported. " + "Please provide a non-null type." ) - a, b = anyof - if a == {"type": "null"}: - schema.update(b) - elif b == {"type": "null"}: - schema.update(a) - else: - raise ValueError( - "Invalid input: Type Unions are not supported, except for `Optional` types. " - "Please provide an `Optional` type or a non-Union type." - ) - schema["nullable"] = True + + + if contains_null: + # set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python + for atype in anyof: + atype["nullable"] = True + properties = schema.get("properties", None) if properties is not None: for name, value in properties.items(): - convert_to_nullable(value) + convert_anyof_null_to_nullable(value, depth=depth + 1) items = schema.get("items", None) if items is not None: - convert_to_nullable(items) + convert_anyof_null_to_nullable(items, depth=depth + 1) def add_object_type(schema): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c60cc1378e..7c9b60d5cb 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -29,6 +29,7 @@ from litellm.types.utils import ( ModelResponseStream, TextCompletionResponse, ) +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -1524,7 +1525,7 @@ class ProxyConfig: yaml.dump(new_config, config_file, default_flow_style=False) def _check_for_os_environ_vars( - self, config: dict, depth: int = 0, max_depth: int = 10 + self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH ) -> dict: """ Check for os.environ/ variables in the config and replace them with the actual values. diff --git a/tests/code_coverage_tests/recursive_detector.py b/tests/code_coverage_tests/recursive_detector.py index 3e9c597b94..b748d1a517 100644 --- a/tests/code_coverage_tests/recursive_detector.py +++ b/tests/code_coverage_tests/recursive_detector.py @@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [ "_check_for_os_environ_vars", "clean_message", "unpack_defs", - "convert_to_nullable", + "convert_anyof_null_to_nullable", # has a set max depth "add_object_type", "strip_field", "_transform_prompt", @@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory): ignored_recursive_functions[file_path] = ignored return recursive_functions, ignored_recursive_functions +if __name__ == "__main__": + # Example usage + # raise exception if any recursive functions are found, except for the ignored ones + # this is used in the CI/CD pipeline to prevent recursive functions from being merged -# Example usage -directory_path = "./litellm" -recursive_functions, ignored_recursive_functions = ( - find_recursive_functions_in_directory(directory_path) -) -print("ALL RECURSIVE FUNCTIONS: ", recursive_functions) -print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions) -if len(recursive_functions) > 0: - raise Exception( - f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." + directory_path = "./litellm" + recursive_functions, ignored_recursive_functions = ( + find_recursive_functions_in_directory(directory_path) ) + print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions) + print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions) + + if len(recursive_functions) > 0: + # raise exception if any recursive functions are found + for file, functions in recursive_functions.items(): + print( + f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." + ) + file, functions = list(recursive_functions.items())[0] + raise Exception( + f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." + ) + diff --git a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py index b94d7495cb..5c2f70527c 100644 --- a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py +++ b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py @@ -1,6 +1,7 @@ import os import sys from unittest.mock import MagicMock, call, patch +from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH import pytest @@ -12,6 +13,7 @@ import litellm from litellm.llms.vertex_ai.common_utils import ( get_vertex_location_from_url, get_vertex_project_id_from_url, + convert_anyof_null_to_nullable ) @@ -42,6 +44,98 @@ async def test_get_vertex_location_from_url(): location = get_vertex_location_from_url(url) assert location is None +def test_basic_anyof_conversion(): + """Test basic conversion of anyOf with 'null'.""" + schema = { + "type": "object", + "properties": { + "example": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ] + } + } + } + + convert_anyof_null_to_nullable(schema) + + expected = { + "type": "object", + "properties": { + "example": { + "anyOf": [ + {"type": "string", "nullable": True} + ] + } + } + } + assert schema == expected + + +def test_nested_anyof_conversion(): + """Test nested conversion with 'anyOf' inside properties.""" + schema = { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": { + "inner": { + "anyOf": [ + {"type": "array", "items": {"type": "string"}}, + {"type": "string"}, + {"type": "null"} + ] + } + } + } + } + } + + convert_anyof_null_to_nullable(schema) + + expected = { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": { + "inner": { + "anyOf": [ + {"type": "array", "items": {"type": "string"}, "nullable": True}, + {"type": "string", "nullable": True} + ] + } + } + } + } + } + assert schema == expected + +def test_anyof_with_excessive_nesting(): + """Test conversion with excessive nesting > max levels +1 deep.""" + # generate a schema with excessive nesting + schema = {"type": "object", "properties": {}} + current = schema + for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1): + current["properties"] = { + "nested": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "properties": {} + } + } + current = current["properties"]["nested"] + + + # running the conversion will raise an error + with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."): + convert_anyof_null_to_nullable(schema) + + @pytest.mark.asyncio async def test_get_supports_system_message(): @@ -59,3 +153,93 @@ async def test_get_supports_system_message(): model="random-model-name", custom_llm_provider="vertex_ai" ) assert result == False +def test_basic_anyof_conversion(): + """Test basic conversion of anyOf with 'null'.""" + schema = { + "type": "object", + "properties": { + "example": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ] + } + } + } + + convert_anyof_null_to_nullable(schema) + + expected = { + "type": "object", + "properties": { + "example": { + "anyOf": [ + {"type": "string", "nullable": True} + ] + } + } + } + assert schema == expected + + +def test_nested_anyof_conversion(): + """Test nested conversion with 'anyOf' inside properties.""" + schema = { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": { + "inner": { + "anyOf": [ + {"type": "array", "items": {"type": "string"}}, + {"type": "string"}, + {"type": "null"} + ] + } + } + } + } + } + + convert_anyof_null_to_nullable(schema) + + expected = { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": { + "inner": { + "anyOf": [ + {"type": "array", "items": {"type": "string"}, "nullable": True}, + {"type": "string", "nullable": True} + ] + } + } + } + } + } + assert schema == expected + +def test_anyof_with_excessive_nesting(): + """Test conversion with excessive nesting > max levels +1 deep.""" + # generate a schema with excessive nesting + schema = {"type": "object", "properties": {}} + current = schema + for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1): + current["properties"] = { + "nested": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "properties": {} + } + } + current = current["properties"]["nested"] + + + # running the conversion will raise an error + with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."): + convert_anyof_null_to_nullable(schema)