Merge pull request #9625 from BerriAI/litellm_mar_28_vertex_fix

Add support to Vertex AI transformation for anyOf union type with null fields
This commit is contained in:
NickGrab 2025-03-28 16:09:29 -07:00 committed by GitHub
commit 70cdc9fc50
8 changed files with 243 additions and 34 deletions

2
.gitignore vendored
View file

@ -83,4 +83,4 @@ tests/llm_translation/test_vertex_key.json
litellm/proxy/migrations/0_init/migration.sql litellm/proxy/migrations/0_init/migration.sql
litellm/proxy/db/migrations/0_init/migration.sql litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/* litellm/proxy/db/migrations/*
litellm/proxy/migrations/* litellm/proxy/migrations/*config.yaml

View file

@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512 DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2 DEFAULT_MAX_RETRIES = 2
DEFAULT_MAX_RECURSE_DEPTH = 10
DEFAULT_FAILURE_THRESHOLD_PERCENT = ( DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute 0.5 # default cooldown a deployment if 50% of requests fail in a given minute
) )

View file

@ -1,8 +1,9 @@
import json import json
from typing import Any, Union from typing import Any, Union
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
def safe_dumps(data: Any, max_depth: int = 10) -> str: def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
""" """
Recursively serialize data while detecting circular references. Recursively serialize data while detecting circular references.
If a circular reference is detected then a marker string is returned. If a circular reference is detected then a marker string is returned.

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
class SensitiveDataMasker: class SensitiveDataMasker:
@ -39,7 +40,7 @@ class SensitiveDataMasker:
return result return result
def mask_dict( def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10 self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if depth >= max_depth: if depth >= max_depth:
return data return data

View file

@ -5,6 +5,7 @@ import httpx
import litellm import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType from litellm.types.llms.vertex_ai import PartType
@ -177,7 +178,7 @@ def _build_vertex_schema(parameters: dict):
# * https://github.com/pydantic/pydantic/issues/1270 # * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311 # * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872 # * https://github.com/pydantic/pydantic/discussions/4872
convert_to_nullable(parameters) convert_anyof_null_to_nullable(parameters)
add_object_type(parameters) add_object_type(parameters)
# Postprocessing # Postprocessing
# 4. Suppress unnecessary title generation: # 4. Suppress unnecessary title generation:
@ -228,34 +229,43 @@ def unpack_defs(schema, defs):
continue continue
def convert_to_nullable(schema): def convert_anyof_null_to_nullable(schema, depth=0):
anyof = schema.pop("anyOf", None) if depth > DEFAULT_MAX_RECURSE_DEPTH:
raise ValueError(
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
)
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
anyof = schema.get("anyOf", None)
if anyof is not None: if anyof is not None:
if len(anyof) != 2: contains_null = False
for atype in anyof:
if atype == {"type": "null"}:
# remove null type
anyof.remove(atype)
contains_null = True
if len(anyof) == 0:
# Edge case: response schema with only null type present is invalid in Vertex AI
raise ValueError( raise ValueError(
"Invalid input: Type Unions are not supported, except for `Optional` types. " "Invalid input: AnyOf schema with only null type is not supported. "
"Please provide an `Optional` type or a non-Union type." "Please provide a non-null type."
) )
a, b = anyof
if a == {"type": "null"}:
schema.update(b) if contains_null:
elif b == {"type": "null"}: # set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python
schema.update(a) for atype in anyof:
else: atype["nullable"] = True
raise ValueError(
"Invalid input: Type Unions are not supported, except for `Optional` types. "
"Please provide an `Optional` type or a non-Union type."
)
schema["nullable"] = True
properties = schema.get("properties", None) properties = schema.get("properties", None)
if properties is not None: if properties is not None:
for name, value in properties.items(): for name, value in properties.items():
convert_to_nullable(value) convert_anyof_null_to_nullable(value, depth=depth + 1)
items = schema.get("items", None) items = schema.get("items", None)
if items is not None: if items is not None:
convert_to_nullable(items) convert_anyof_null_to_nullable(items, depth=depth + 1)
def add_object_type(schema): def add_object_type(schema):

View file

@ -29,6 +29,7 @@ from litellm.types.utils import (
ModelResponseStream, ModelResponseStream,
TextCompletionResponse, TextCompletionResponse,
) )
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span from opentelemetry.trace import Span as _Span
@ -1524,7 +1525,7 @@ class ProxyConfig:
yaml.dump(new_config, config_file, default_flow_style=False) yaml.dump(new_config, config_file, default_flow_style=False)
def _check_for_os_environ_vars( def _check_for_os_environ_vars(
self, config: dict, depth: int = 0, max_depth: int = 10 self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
) -> dict: ) -> dict:
""" """
Check for os.environ/ variables in the config and replace them with the actual values. Check for os.environ/ variables in the config and replace them with the actual values.

View file

@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [
"_check_for_os_environ_vars", "_check_for_os_environ_vars",
"clean_message", "clean_message",
"unpack_defs", "unpack_defs",
"convert_to_nullable", "convert_anyof_null_to_nullable", # has a set max depth
"add_object_type", "add_object_type",
"strip_field", "strip_field",
"_transform_prompt", "_transform_prompt",
@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory):
ignored_recursive_functions[file_path] = ignored ignored_recursive_functions[file_path] = ignored
return recursive_functions, ignored_recursive_functions return recursive_functions, ignored_recursive_functions
if __name__ == "__main__":
# Example usage # Example usage
# raise exception if any recursive functions are found, except for the ignored ones
# this is used in the CI/CD pipeline to prevent recursive functions from being merged
directory_path = "./litellm" directory_path = "./litellm"
recursive_functions, ignored_recursive_functions = ( recursive_functions, ignored_recursive_functions = (
find_recursive_functions_in_directory(directory_path) find_recursive_functions_in_directory(directory_path)
) )
print("ALL RECURSIVE FUNCTIONS: ", recursive_functions) print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions) print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
if len(recursive_functions) > 0: if len(recursive_functions) > 0:
raise Exception( # raise exception if any recursive functions are found
f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary." for file, functions in recursive_functions.items():
print(
f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
) )
file, functions = list(recursive_functions.items())[0]
raise Exception(
f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
)

View file

@ -1,6 +1,7 @@
import os import os
import sys import sys
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
import pytest import pytest
@ -12,6 +13,7 @@ import litellm
from litellm.llms.vertex_ai.common_utils import ( from litellm.llms.vertex_ai.common_utils import (
get_vertex_location_from_url, get_vertex_location_from_url,
get_vertex_project_id_from_url, get_vertex_project_id_from_url,
convert_anyof_null_to_nullable
) )
@ -42,6 +44,98 @@ async def test_get_vertex_location_from_url():
location = get_vertex_location_from_url(url) location = get_vertex_location_from_url(url)
assert location is None assert location is None
def test_basic_anyof_conversion():
"""Test basic conversion of anyOf with 'null'."""
schema = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string"},
{"type": "null"}
]
}
}
}
convert_anyof_null_to_nullable(schema)
expected = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string", "nullable": True}
]
}
}
}
assert schema == expected
def test_nested_anyof_conversion():
"""Test nested conversion with 'anyOf' inside properties."""
schema = {
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}},
{"type": "string"},
{"type": "null"}
]
}
}
}
}
}
convert_anyof_null_to_nullable(schema)
expected = {
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}, "nullable": True},
{"type": "string", "nullable": True}
]
}
}
}
}
}
assert schema == expected
def test_anyof_with_excessive_nesting():
"""Test conversion with excessive nesting > max levels +1 deep."""
# generate a schema with excessive nesting
schema = {"type": "object", "properties": {}}
current = schema
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
current["properties"] = {
"nested": {
"anyOf": [
{"type": "string"},
{"type": "null"}
],
"properties": {}
}
}
current = current["properties"]["nested"]
# running the conversion will raise an error
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
convert_anyof_null_to_nullable(schema)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_supports_system_message(): async def test_get_supports_system_message():
@ -59,3 +153,93 @@ async def test_get_supports_system_message():
model="random-model-name", custom_llm_provider="vertex_ai" model="random-model-name", custom_llm_provider="vertex_ai"
) )
assert result == False assert result == False
def test_basic_anyof_conversion():
"""Test basic conversion of anyOf with 'null'."""
schema = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string"},
{"type": "null"}
]
}
}
}
convert_anyof_null_to_nullable(schema)
expected = {
"type": "object",
"properties": {
"example": {
"anyOf": [
{"type": "string", "nullable": True}
]
}
}
}
assert schema == expected
def test_nested_anyof_conversion():
"""Test nested conversion with 'anyOf' inside properties."""
schema = {
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}},
{"type": "string"},
{"type": "null"}
]
}
}
}
}
}
convert_anyof_null_to_nullable(schema)
expected = {
"type": "object",
"properties": {
"outer": {
"type": "object",
"properties": {
"inner": {
"anyOf": [
{"type": "array", "items": {"type": "string"}, "nullable": True},
{"type": "string", "nullable": True}
]
}
}
}
}
}
assert schema == expected
def test_anyof_with_excessive_nesting():
"""Test conversion with excessive nesting > max levels +1 deep."""
# generate a schema with excessive nesting
schema = {"type": "object", "properties": {}}
current = schema
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
current["properties"] = {
"nested": {
"anyOf": [
{"type": "string"},
{"type": "null"}
],
"properties": {}
}
}
current = current["properties"]["nested"]
# running the conversion will raise an error
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
convert_anyof_null_to_nullable(schema)