fix(vertex_ai.py): pass only the keys accepted by the Vertex AI response schema (#8992)

* fix(vertex_ai.py): common_utils.py

Pass only the keys that the Vertex AI schema accepts.

Prevent JSON Schema keys that Vertex AI does not accept, such as $id and $comment, from causing Vertex AI OpenAPI calls to fail.

* fix(test_vertex.py): add testing to ensure only accepted schema params passed in

* fix(common_utils.py): fix linting error

* test: update test

* test: accept function
This commit is contained in:
Krish Dholakia 2025-04-07 18:07:01 -07:00 committed by GitHub
parent 4a128cfd64
commit 8e3c7b2de0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 78 additions and 30 deletions

View file

@ -1,5 +1,5 @@
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
import re
from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
@ -7,7 +7,7 @@ import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
from litellm.types.llms.vertex_ai import PartType, Schema
class VertexAIError(BaseLLMException):
@ -168,6 +168,9 @@ def _build_vertex_schema(parameters: dict):
"""
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
"""
# Get valid fields from Schema TypedDict
valid_schema_fields = set(get_type_hints(Schema).keys())
defs = parameters.pop("$defs", {})
# flatten the defs
for name, value in defs.items():
@ -181,19 +184,49 @@ def _build_vertex_schema(parameters: dict):
convert_anyof_null_to_nullable(parameters)
add_object_type(parameters)
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
strip_field(parameters, field_name="title")
# Filter out fields that don't exist in Schema
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
return filtered_parameters
strip_field(
parameters, field_name="$schema"
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
strip_field(
parameters, field_name="$id"
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
return parameters
def filter_schema_fields(
    schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
) -> Dict[str, Any]:
    """
    Recursively filter a schema dictionary to keep only valid fields.

    A new dictionary is built; ``schema_dict`` itself is never mutated.
    Nested schemas are filtered when they appear under ``properties``
    (a mapping of name -> schema), ``items`` (a single schema) or ``anyOf``
    (a list of schemas). Every other accepted key is copied through as-is.

    Args:
        schema_dict: The schema to filter. Non-dict values are returned
            unchanged so the function can be applied to arbitrary nested
            values.
        valid_fields: Key names that are allowed to survive filtering.
        processed: Internal bookkeeping set holding the ``id()`` of every
            dict currently being filtered on the active recursion path.
            Callers normally leave this as ``None``.

    Returns:
        A filtered copy of ``schema_dict``. If a dict is reached again while
        it is still being filtered (a true circular reference), the original
        dict is returned at that point to stop the recursion.
    """
    # Non-dict values (scalars, lists, None, ...) have no keys to filter.
    if not isinstance(schema_dict, dict):
        return schema_dict

    if processed is None:
        processed = set()

    # Handle circular references: only dicts that are *ancestors* of the
    # current node count. A sub-schema that is merely shared between two
    # siblings must still be filtered every time it is encountered.
    schema_id = id(schema_dict)
    if schema_id in processed:
        return schema_dict

    processed.add(schema_id)
    try:
        result: Dict[str, Any] = {}
        for key, value in schema_dict.items():
            if key not in valid_fields:
                continue

            if key == "properties" and isinstance(value, dict):
                result[key] = {
                    k: filter_schema_fields(v, valid_fields, processed)
                    for k, v in value.items()
                }
            elif key == "items" and isinstance(value, dict):
                result[key] = filter_schema_fields(value, valid_fields, processed)
            elif key == "anyOf" and isinstance(value, list):
                result[key] = [
                    filter_schema_fields(item, valid_fields, processed)
                    for item in value
                ]
            else:
                result[key] = value
        return result
    finally:
        # Leaving this node: it is no longer on the recursion path, so a
        # later (non-circular) reference to the same object is filtered too.
        processed.discard(schema_id)
def unpack_defs(schema, defs):

View file

@ -87,12 +87,27 @@ class SystemInstructions(TypedDict):
class Schema(TypedDict, total=False):
type: Literal["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]
format: str
title: str
description: str
enum: List[str]
items: List["Schema"]
properties: "Schema"
required: List[str]
nullable: bool
default: Any
items: "Schema"
minItems: str
maxItems: str
enum: List[str]
properties: Dict[str, "Schema"]
propertyOrdering: List[str]
required: List[str]
minProperties: str
maxProperties: str
minimum: float
maximum: float
minLength: str
maxLength: str
pattern: str
example: Any
anyOf: List["Schema"]
class FunctionDeclaration(TypedDict, total=False):

View file

@ -5,6 +5,7 @@ IGNORE_FUNCTIONS = [
"_format_type",
"_remove_additional_properties",
"_remove_strict_from_schema",
"filter_schema_fields",
"text_completion",
"_check_for_os_environ_vars",
"clean_message",

View file

@ -63,26 +63,24 @@ def test_completion_pydantic_obj_2():
"events": {
"items": {
"properties": {
"name": {"type": "string"},
"date": {"type": "string"},
"name": {"title": "Name", "type": "string"},
"date": {"title": "Date", "type": "string"},
"participants": {
"items": {"type": "string"},
"title": "Participants",
"type": "array",
},
},
"required": [
"name",
"date",
"participants",
],
"required": ["name", "date", "participants"],
"title": "CalendarEvent",
"type": "object",
},
"title": "Events",
"type": "array",
}
},
"required": [
"events",
],
"required": ["events"],
"title": "EventsList",
"type": "object",
},
},
@ -91,12 +89,13 @@ def test_completion_pydantic_obj_2():
with patch.object(client, "post", new=MagicMock()) as mock_post:
mock_post.return_value = expected_request_body
try:
litellm.completion(
response = litellm.completion(
model="gemini/gemini-1.5-pro",
messages=messages,
response_format=EventsList,
client=client,
)
# print(response)
except Exception as e:
print(e)
@ -115,7 +114,7 @@ def test_build_vertex_schema():
schema = {
"type": "object",
"$id": "my-special-id",
"my-random-key": "my-random-value",
"properties": {
"recipes": {
"type": "array",
@ -134,7 +133,7 @@ def test_build_vertex_schema():
assert new_schema["type"] == schema["type"]
assert new_schema["properties"] == schema["properties"]
assert "required" in new_schema and new_schema["required"] == schema["required"]
assert "$id" not in new_schema
assert "my-random-key" not in new_schema
@pytest.mark.parametrize(