mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge pull request #9625 from BerriAI/litellm_mar_28_vertex_fix
Add support to Vertex AI transformation for anyOf union type with null fields
This commit is contained in:
commit
70cdc9fc50
8 changed files with 243 additions and 34 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -83,4 +83,4 @@ tests/llm_translation/test_vertex_key.json
|
||||||
litellm/proxy/migrations/0_init/migration.sql
|
litellm/proxy/migrations/0_init/migration.sql
|
||||||
litellm/proxy/db/migrations/0_init/migration.sql
|
litellm/proxy/db/migrations/0_init/migration.sql
|
||||||
litellm/proxy/db/migrations/*
|
litellm/proxy/db/migrations/*
|
||||||
litellm/proxy/migrations/*
|
litellm/proxy/migrations/*config.yaml
|
||||||
|
|
|
@ -4,6 +4,7 @@ ROUTER_MAX_FALLBACKS = 5
|
||||||
DEFAULT_BATCH_SIZE = 512
|
DEFAULT_BATCH_SIZE = 512
|
||||||
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
||||||
DEFAULT_MAX_RETRIES = 2
|
DEFAULT_MAX_RETRIES = 2
|
||||||
|
DEFAULT_MAX_RECURSE_DEPTH = 10
|
||||||
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
||||||
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import json
|
import json
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
|
|
||||||
def safe_dumps(data: Any, max_depth: int = 10) -> str:
|
def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||||
"""
|
"""
|
||||||
Recursively serialize data while detecting circular references.
|
Recursively serialize data while detecting circular references.
|
||||||
If a circular reference is detected then a marker string is returned.
|
If a circular reference is detected then a marker string is returned.
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Any, Dict, Optional, Set
|
from typing import Any, Dict, Optional, Set
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
|
|
||||||
class SensitiveDataMasker:
|
class SensitiveDataMasker:
|
||||||
|
@ -39,7 +40,7 @@ class SensitiveDataMasker:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def mask_dict(
|
def mask_dict(
|
||||||
self, data: Dict[str, Any], depth: int = 0, max_depth: int = 10
|
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if depth >= max_depth:
|
if depth >= max_depth:
|
||||||
return data
|
return data
|
||||||
|
|
|
@ -5,6 +5,7 @@ import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
from litellm.types.llms.vertex_ai import PartType
|
from litellm.types.llms.vertex_ai import PartType
|
||||||
|
|
||||||
|
@ -177,7 +178,7 @@ def _build_vertex_schema(parameters: dict):
|
||||||
# * https://github.com/pydantic/pydantic/issues/1270
|
# * https://github.com/pydantic/pydantic/issues/1270
|
||||||
# * https://stackoverflow.com/a/58841311
|
# * https://stackoverflow.com/a/58841311
|
||||||
# * https://github.com/pydantic/pydantic/discussions/4872
|
# * https://github.com/pydantic/pydantic/discussions/4872
|
||||||
convert_to_nullable(parameters)
|
convert_anyof_null_to_nullable(parameters)
|
||||||
add_object_type(parameters)
|
add_object_type(parameters)
|
||||||
# Postprocessing
|
# Postprocessing
|
||||||
# 4. Suppress unnecessary title generation:
|
# 4. Suppress unnecessary title generation:
|
||||||
|
@ -228,34 +229,43 @@ def unpack_defs(schema, defs):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
def convert_to_nullable(schema):
|
def convert_anyof_null_to_nullable(schema, depth=0):
|
||||||
anyof = schema.pop("anyOf", None)
|
if depth > DEFAULT_MAX_RECURSE_DEPTH:
|
||||||
|
raise ValueError(
|
||||||
|
f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."
|
||||||
|
)
|
||||||
|
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
|
||||||
|
anyof = schema.get("anyOf", None)
|
||||||
if anyof is not None:
|
if anyof is not None:
|
||||||
if len(anyof) != 2:
|
contains_null = False
|
||||||
|
for atype in anyof:
|
||||||
|
if atype == {"type": "null"}:
|
||||||
|
# remove null type
|
||||||
|
anyof.remove(atype)
|
||||||
|
contains_null = True
|
||||||
|
|
||||||
|
if len(anyof) == 0:
|
||||||
|
# Edge case: response schema with only null type present is invalid in Vertex AI
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid input: Type Unions are not supported, except for `Optional` types. "
|
"Invalid input: AnyOf schema with only null type is not supported. "
|
||||||
"Please provide an `Optional` type or a non-Union type."
|
"Please provide a non-null type."
|
||||||
)
|
)
|
||||||
a, b = anyof
|
|
||||||
if a == {"type": "null"}:
|
|
||||||
schema.update(b)
|
if contains_null:
|
||||||
elif b == {"type": "null"}:
|
# set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python
|
||||||
schema.update(a)
|
for atype in anyof:
|
||||||
else:
|
atype["nullable"] = True
|
||||||
raise ValueError(
|
|
||||||
"Invalid input: Type Unions are not supported, except for `Optional` types. "
|
|
||||||
"Please provide an `Optional` type or a non-Union type."
|
|
||||||
)
|
|
||||||
schema["nullable"] = True
|
|
||||||
|
|
||||||
properties = schema.get("properties", None)
|
properties = schema.get("properties", None)
|
||||||
if properties is not None:
|
if properties is not None:
|
||||||
for name, value in properties.items():
|
for name, value in properties.items():
|
||||||
convert_to_nullable(value)
|
convert_anyof_null_to_nullable(value, depth=depth + 1)
|
||||||
|
|
||||||
items = schema.get("items", None)
|
items = schema.get("items", None)
|
||||||
if items is not None:
|
if items is not None:
|
||||||
convert_to_nullable(items)
|
convert_anyof_null_to_nullable(items, depth=depth + 1)
|
||||||
|
|
||||||
|
|
||||||
def add_object_type(schema):
|
def add_object_type(schema):
|
||||||
|
|
|
@ -29,6 +29,7 @@ from litellm.types.utils import (
|
||||||
ModelResponseStream,
|
ModelResponseStream,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
)
|
)
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -1524,7 +1525,7 @@ class ProxyConfig:
|
||||||
yaml.dump(new_config, config_file, default_flow_style=False)
|
yaml.dump(new_config, config_file, default_flow_style=False)
|
||||||
|
|
||||||
def _check_for_os_environ_vars(
|
def _check_for_os_environ_vars(
|
||||||
self, config: dict, depth: int = 0, max_depth: int = 10
|
self, config: dict, depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Check for os.environ/ variables in the config and replace them with the actual values.
|
Check for os.environ/ variables in the config and replace them with the actual values.
|
||||||
|
|
|
@ -9,7 +9,7 @@ IGNORE_FUNCTIONS = [
|
||||||
"_check_for_os_environ_vars",
|
"_check_for_os_environ_vars",
|
||||||
"clean_message",
|
"clean_message",
|
||||||
"unpack_defs",
|
"unpack_defs",
|
||||||
"convert_to_nullable",
|
"convert_anyof_null_to_nullable", # has a set max depth
|
||||||
"add_object_type",
|
"add_object_type",
|
||||||
"strip_field",
|
"strip_field",
|
||||||
"_transform_prompt",
|
"_transform_prompt",
|
||||||
|
@ -76,15 +76,26 @@ def find_recursive_functions_in_directory(directory):
|
||||||
ignored_recursive_functions[file_path] = ignored
|
ignored_recursive_functions[file_path] = ignored
|
||||||
return recursive_functions, ignored_recursive_functions
|
return recursive_functions, ignored_recursive_functions
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
# Example usage
|
# Example usage
|
||||||
|
# raise exception if any recursive functions are found, except for the ignored ones
|
||||||
|
# this is used in the CI/CD pipeline to prevent recursive functions from being merged
|
||||||
|
|
||||||
directory_path = "./litellm"
|
directory_path = "./litellm"
|
||||||
recursive_functions, ignored_recursive_functions = (
|
recursive_functions, ignored_recursive_functions = (
|
||||||
find_recursive_functions_in_directory(directory_path)
|
find_recursive_functions_in_directory(directory_path)
|
||||||
)
|
)
|
||||||
print("ALL RECURSIVE FUNCTIONS: ", recursive_functions)
|
print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
|
||||||
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
|
print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)
|
||||||
|
|
||||||
if len(recursive_functions) > 0:
|
if len(recursive_functions) > 0:
|
||||||
raise Exception(
|
# raise exception if any recursive functions are found
|
||||||
f"🚨 Recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
for file, functions in recursive_functions.items():
|
||||||
|
print(
|
||||||
|
f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||||
)
|
)
|
||||||
|
file, functions = list(recursive_functions.items())[0]
|
||||||
|
raise Exception(
|
||||||
|
f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -12,6 +13,7 @@ import litellm
|
||||||
from litellm.llms.vertex_ai.common_utils import (
|
from litellm.llms.vertex_ai.common_utils import (
|
||||||
get_vertex_location_from_url,
|
get_vertex_location_from_url,
|
||||||
get_vertex_project_id_from_url,
|
get_vertex_project_id_from_url,
|
||||||
|
convert_anyof_null_to_nullable
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,6 +44,98 @@ async def test_get_vertex_location_from_url():
|
||||||
location = get_vertex_location_from_url(url)
|
location = get_vertex_location_from_url(url)
|
||||||
assert location is None
|
assert location is None
|
||||||
|
|
||||||
|
def test_basic_anyof_conversion():
|
||||||
|
"""Test basic conversion of anyOf with 'null'."""
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"example": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"example": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string", "nullable": True}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert schema == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_anyof_conversion():
|
||||||
|
"""Test nested conversion with 'anyOf' inside properties."""
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"outer": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"inner": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "array", "items": {"type": "string"}},
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"outer": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"inner": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "array", "items": {"type": "string"}, "nullable": True},
|
||||||
|
{"type": "string", "nullable": True}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert schema == expected
|
||||||
|
|
||||||
|
def test_anyof_with_excessive_nesting():
|
||||||
|
"""Test conversion with excessive nesting > max levels +1 deep."""
|
||||||
|
# generate a schema with excessive nesting
|
||||||
|
schema = {"type": "object", "properties": {}}
|
||||||
|
current = schema
|
||||||
|
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
|
||||||
|
current["properties"] = {
|
||||||
|
"nested": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
],
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = current["properties"]["nested"]
|
||||||
|
|
||||||
|
|
||||||
|
# running the conversion will raise an error
|
||||||
|
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_supports_system_message():
|
async def test_get_supports_system_message():
|
||||||
|
@ -59,3 +153,93 @@ async def test_get_supports_system_message():
|
||||||
model="random-model-name", custom_llm_provider="vertex_ai"
|
model="random-model-name", custom_llm_provider="vertex_ai"
|
||||||
)
|
)
|
||||||
assert result == False
|
assert result == False
|
||||||
|
def test_basic_anyof_conversion():
|
||||||
|
"""Test basic conversion of anyOf with 'null'."""
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"example": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"example": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string", "nullable": True}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert schema == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_anyof_conversion():
|
||||||
|
"""Test nested conversion with 'anyOf' inside properties."""
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"outer": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"inner": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "array", "items": {"type": "string"}},
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"outer": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"inner": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "array", "items": {"type": "string"}, "nullable": True},
|
||||||
|
{"type": "string", "nullable": True}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert schema == expected
|
||||||
|
|
||||||
|
def test_anyof_with_excessive_nesting():
|
||||||
|
"""Test conversion with excessive nesting > max levels +1 deep."""
|
||||||
|
# generate a schema with excessive nesting
|
||||||
|
schema = {"type": "object", "properties": {}}
|
||||||
|
current = schema
|
||||||
|
for _ in range(DEFAULT_MAX_RECURSE_DEPTH + 1):
|
||||||
|
current["properties"] = {
|
||||||
|
"nested": {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "null"}
|
||||||
|
],
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = current["properties"]["nested"]
|
||||||
|
|
||||||
|
|
||||||
|
# running the conversion will raise an error
|
||||||
|
with pytest.raises(ValueError, match=f"Max depth of {DEFAULT_MAX_RECURSE_DEPTH} exceeded while processing schema. Please check the schema for excessive nesting."):
|
||||||
|
convert_anyof_null_to_nullable(schema)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue