diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py index 6c53477d8..4f92ec7d9 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_toolchain/common/prompt_for_config.py @@ -11,6 +11,7 @@ from enum import Enum from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union from pydantic import BaseModel +from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefinedType from typing_extensions import Annotated @@ -26,6 +27,12 @@ def is_list_of_primitives(field_type): return False +def can_recurse(typ): + return ( + inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0 + ) + + def get_literal_values(field): """Extract literal values from a field if it's a Literal type.""" if get_origin(field.annotation) is Literal: @@ -52,6 +59,60 @@ def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any) return value +def is_discriminated_union(typ) -> bool: + if isinstance(typ, FieldInfo): + return typ.discriminator + else: + if not (get_origin(typ) is Annotated): + return False + args = get_args(typ) + return len(args) >= 2 and args[1].discriminator + + +def prompt_for_discriminated_union( + field_name, + typ, + existing_value, +): + if isinstance(typ, FieldInfo): + inner_type = typ.annotation + discriminator = typ.discriminator + else: + args = get_args(typ) + inner_type = args[0] + discriminator = args[1].discriminator + + union_types = get_args(inner_type) + # Find the discriminator field in each union type + type_map = {} + for t in union_types: + disc_field = t.__fields__[discriminator] + literal_values = get_literal_values(disc_field) + if literal_values: + for value in literal_values: + type_map[value] = t + + while True: + discriminator_value = input( + f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): " + ) + if discriminator_value in type_map: + chosen_type = type_map[discriminator_value] + print(f"\nConfiguring {chosen_type.__name__}:") + + if existing_value and ( + getattr(existing_value, discriminator) != discriminator_value + ): + existing_value = None + + sub_config = prompt_for_config(chosen_type, existing_value) + # Set the discriminator field in the sub-config + setattr(sub_config, discriminator, discriminator_value) + return sub_config + else: + print(f"Invalid {discriminator}. Please try again.") + + # This is somewhat elaborate, but does not purport to be comprehensive in any way. # We should add handling for the most common cases to tide us over. # @@ -73,7 +134,6 @@ def prompt_for_config( for field_name, field in config_type.__fields__.items(): field_type = field.annotation - existing_value = ( getattr(existing_config, field_name) if existing_config else None ) @@ -107,50 +167,13 @@ def prompt_for_config( ) continue - # Check if the field is a discriminated union - if get_origin(field_type) is Annotated: - inner_type = get_args(field_type)[0] - if get_origin(inner_type) is Union: - discriminator = field.field_info.discriminator - if discriminator: - union_types = get_args(inner_type) - # Find the discriminator field in each union type - type_map = {} - for t in union_types: - disc_field = t.__fields__[discriminator] - literal_values = get_literal_values(disc_field) - if literal_values: - for value in literal_values: - type_map[value] = t + if is_discriminated_union(field): + config_data[field_name] = prompt_for_discriminated_union( + field_name, field, existing_value + ) + continue - while True: - discriminator_value = input( - f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): " - ) - if discriminator_value in type_map: - chosen_type = type_map[discriminator_value] - print(f"\nConfiguring {chosen_type.__name__}:") - - if existing_value and ( - getattr(existing_value, discriminator) - != discriminator_value - ): - existing_value = None - - sub_config = prompt_for_config(chosen_type, existing_value) - config_data[field_name] = sub_config - # Set the discriminator field in the sub-config - setattr(sub_config, discriminator, discriminator_value) - break - else: - print(f"Invalid {discriminator}. Please try again.") - continue - - if ( - is_optional(field_type) - and inspect.isclass(get_non_none_type(field_type)) - and issubclass(get_non_none_type(field_type), BaseModel) - ): + if is_optional(field_type) and can_recurse(get_non_none_type(field_type)): prompt = f"Do you want to configure {field_name}? (y/n): " if input(prompt).lower() == "n": config_data[field_name] = None @@ -158,11 +181,20 @@ def prompt_for_config( nested_type = get_non_none_type(field_type) print(f"Entering sub-configuration for {field_name}:") config_data[field_name] = prompt_for_config(nested_type, existing_value) - elif ( - inspect.isclass(field_type) - and issubclass(field_type, BaseModel) - and len(field_type.__fields__) > 0 + elif is_optional(field_type) and is_discriminated_union( + get_non_none_type(field_type) ): + prompt = f"Do you want to configure {field_name}? (y/n): " + if input(prompt).lower() == "n": + config_data[field_name] = None + continue + nested_type = get_non_none_type(field_type) + config_data[field_name] = prompt_for_discriminated_union( + field_name, + nested_type, + existing_value, + ) + elif can_recurse(field_type): print(f"\nEntering sub-configuration for {field_name}:") config_data[field_name] = prompt_for_config( field_type,