Handle Annotated types more correctly

This commit is contained in:
Ashwin Bharambe 2024-09-14 12:22:13 -07:00
parent 53ab18d6bb
commit 73b71d9689

View file

@ -11,6 +11,7 @@ from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
@ -26,6 +27,12 @@ def is_list_of_primitives(field_type):
return False
def can_recurse(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
)
def get_literal_values(field):
"""Extract literal values from a field if it's a Literal type."""
if get_origin(field.annotation) is Literal:
@ -52,6 +59,60 @@ def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any)
return value
def is_discriminated_union(typ) -> bool:
if isinstance(typ, FieldInfo):
return typ.discriminator
else:
if not (get_origin(typ) is Annotated):
return False
args = get_args(typ)
return len(args) >= 2 and args[1].discriminator
def prompt_for_discriminated_union(
field_name,
typ,
existing_value,
):
if isinstance(typ, FieldInfo):
inner_type = typ.annotation
discriminator = typ.discriminator
else:
args = get_args(typ)
inner_type = args[0]
discriminator = args[1].discriminator
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
while True:
discriminator_value = input(
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
)
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
return sub_config
else:
print(f"Invalid {discriminator}. Please try again.")
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
# We should add handling for the most common cases to tide us over.
#
@ -73,7 +134,6 @@ def prompt_for_config(
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = (
getattr(existing_config, field_name) if existing_config else None
)
@ -107,50 +167,13 @@ def prompt_for_config(
)
continue
# Check if the field is a discriminated union
if get_origin(field_type) is Annotated:
inner_type = get_args(field_type)[0]
if get_origin(inner_type) is Union:
discriminator = field.field_info.discriminator
if discriminator:
union_types = get_args(inner_type)
# Find the discriminator field in each union type
type_map = {}
for t in union_types:
disc_field = t.__fields__[discriminator]
literal_values = get_literal_values(disc_field)
if literal_values:
for value in literal_values:
type_map[value] = t
if is_discriminated_union(field):
config_data[field_name] = prompt_for_discriminated_union(
field_name, field, existing_value
)
continue
while True:
discriminator_value = input(
f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
)
if discriminator_value in type_map:
chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator)
!= discriminator_value
):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
config_data[field_name] = sub_config
# Set the discriminator field in the sub-config
setattr(sub_config, discriminator, discriminator_value)
break
else:
print(f"Invalid {discriminator}. Please try again.")
continue
if (
is_optional(field_type)
and inspect.isclass(get_non_none_type(field_type))
and issubclass(get_non_none_type(field_type), BaseModel)
):
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
@ -158,11 +181,20 @@ def prompt_for_config(
nested_type = get_non_none_type(field_type)
print(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif (
inspect.isclass(field_type)
and issubclass(field_type, BaseModel)
and len(field_type.__fields__) > 0
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
continue
nested_type = get_non_none_type(field_type)
config_data[field_name] = prompt_for_discriminated_union(
field_name,
nested_type,
existing_value,
)
elif can_recurse(field_type):
print(f"\nEntering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(
field_type,