8864 Add support for anyOf union type while handling null fields

This commit is contained in:
Nicholas Grabar 2025-03-25 22:37:28 -07:00
parent 122ee634f4
commit f68cc26f15
3 changed files with 95 additions and 20 deletions

View file

@ -160,7 +160,7 @@ def _build_vertex_schema(parameters: dict):
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
convert_to_nullable(parameters)
convert_anyof_null_to_nullable(parameters)
add_object_type(parameters)
# Postprocessing
# 4. Suppress unnecessary title generation:
@ -211,34 +211,39 @@ def unpack_defs(schema, defs):
continue
def convert_to_nullable(schema):
anyof = schema.pop("anyOf", None)
def convert_anyof_null_to_nullable(schema):
""" Converts null objects within anyOf by removing them and adding nullable to all remaining objects """
anyof = schema.get("anyOf", None)
if anyof is not None:
if len(anyof) != 2:
contains_null = False
for atype in anyof:
if atype == {"type": "null"}:
# remove null type
anyof.remove(atype)
contains_null = True
if len(anyof) == 0:
# Edge case: response schema with only null type present is invalid in Vertex AI
raise ValueError(
"Invalid input: Type Unions are not supported, except for `Optional` types. "
"Please provide an `Optional` type or a non-Union type."
"Invalid input: AnyOf schema with only null type is not supported. "
"Please provide a non-null type."
)
a, b = anyof
if a == {"type": "null"}:
schema.update(b)
elif b == {"type": "null"}:
schema.update(a)
else:
raise ValueError(
"Invalid input: Type Unions are not supported, except for `Optional` types. "
"Please provide an `Optional` type or a non-Union type."
)
schema["nullable"] = True
if contains_null:
# set all types to nullable following guidance found here: https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-controlled-generation-response-schema-3#generativeaionvertexai_gemini_controlled_generation_response_schema_3-python
for atype in anyof:
atype["nullable"] = True
properties = schema.get("properties", None)
if properties is not None:
for name, value in properties.items():
convert_to_nullable(value)
convert_anyof_null_to_nullable(value)
items = schema.get("items", None)
if items is not None:
convert_to_nullable(items)
convert_anyof_null_to_nullable(items)
def add_object_type(schema):