fix(vertex_ai.py): move to only passing in accepted keys by vertex ai response schema (#8992)

* fix(vertex_ai.py): common_utils.py

move to only passing in accepted keys by vertex ai

prevent json schema compatible keys like $id, and $comment from causing vertex ai openapi calls to fail

* fix(test_vertex.py): add testing to ensure only accepted schema params passed in

* fix(common_utils.py): fix linting error

* test: update test

* test: accept function
This commit is contained in:
Krish Dholakia 2025-04-07 18:07:01 -07:00 committed by GitHub
parent 4a128cfd64
commit 8e3c7b2de0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 78 additions and 30 deletions

View file

@ -1,5 +1,5 @@
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
import re
from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
@ -7,7 +7,7 @@ import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
from litellm.types.llms.vertex_ai import PartType, Schema
class VertexAIError(BaseLLMException):
@ -168,6 +168,9 @@ def _build_vertex_schema(parameters: dict):
"""
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
"""
# Get valid fields from Schema TypedDict
valid_schema_fields = set(get_type_hints(Schema).keys())
defs = parameters.pop("$defs", {})
# flatten the defs
for name, value in defs.items():
@ -181,19 +184,49 @@ def _build_vertex_schema(parameters: dict):
convert_anyof_null_to_nullable(parameters)
add_object_type(parameters)
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
strip_field(parameters, field_name="title")
# Filter out fields that don't exist in Schema
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
return filtered_parameters
strip_field(
parameters, field_name="$schema"
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
strip_field(
parameters, field_name="$id"
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
return parameters
def filter_schema_fields(
schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
) -> Dict[str, Any]:
"""
Recursively filter a schema dictionary to keep only valid fields.
"""
if processed is None:
processed = set()
# Handle circular references
schema_id = id(schema_dict)
if schema_id in processed:
return schema_dict
processed.add(schema_id)
if not isinstance(schema_dict, dict):
return schema_dict
result = {}
for key, value in schema_dict.items():
if key not in valid_fields:
continue
if key == "properties" and isinstance(value, dict):
result[key] = {
k: filter_schema_fields(v, valid_fields, processed)
for k, v in value.items()
}
elif key == "items" and isinstance(value, dict):
result[key] = filter_schema_fields(value, valid_fields, processed)
elif key == "anyOf" and isinstance(value, list):
result[key] = [
filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore
]
else:
result[key] = value
return result
def unpack_defs(schema, defs):