mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Handle Annotated types more correctly
This commit is contained in:
parent
53ab18d6bb
commit
73b71d9689
1 changed files with 80 additions and 48 deletions
|
@ -11,6 +11,7 @@ from enum import Enum
|
|||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
@ -26,6 +27,12 @@ def is_list_of_primitives(field_type):
|
|||
return False
|
||||
|
||||
|
||||
def can_recurse(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
)
|
||||
|
||||
|
||||
def get_literal_values(field):
|
||||
"""Extract literal values from a field if it's a Literal type."""
|
||||
if get_origin(field.annotation) is Literal:
|
||||
|
@ -52,6 +59,60 @@ def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any)
|
|||
return value
|
||||
|
||||
|
||||
def is_discriminated_union(typ) -> bool:
|
||||
if isinstance(typ, FieldInfo):
|
||||
return typ.discriminator
|
||||
else:
|
||||
if not (get_origin(typ) is Annotated):
|
||||
return False
|
||||
args = get_args(typ)
|
||||
return len(args) >= 2 and args[1].discriminator
|
||||
|
||||
|
||||
def prompt_for_discriminated_union(
|
||||
field_name,
|
||||
typ,
|
||||
existing_value,
|
||||
):
|
||||
if isinstance(typ, FieldInfo):
|
||||
inner_type = typ.annotation
|
||||
discriminator = typ.discriminator
|
||||
else:
|
||||
args = get_args(typ)
|
||||
inner_type = args[0]
|
||||
discriminator = args[1].discriminator
|
||||
|
||||
union_types = get_args(inner_type)
|
||||
# Find the discriminator field in each union type
|
||||
type_map = {}
|
||||
for t in union_types:
|
||||
disc_field = t.__fields__[discriminator]
|
||||
literal_values = get_literal_values(disc_field)
|
||||
if literal_values:
|
||||
for value in literal_values:
|
||||
type_map[value] = t
|
||||
|
||||
while True:
|
||||
discriminator_value = input(
|
||||
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
|
||||
)
|
||||
if discriminator_value in type_map:
|
||||
chosen_type = type_map[discriminator_value]
|
||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator) != discriminator_value
|
||||
):
|
||||
existing_value = None
|
||||
|
||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||
# Set the discriminator field in the sub-config
|
||||
setattr(sub_config, discriminator, discriminator_value)
|
||||
return sub_config
|
||||
else:
|
||||
print(f"Invalid {discriminator}. Please try again.")
|
||||
|
||||
|
||||
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
||||
# We should add handling for the most common cases to tide us over.
|
||||
#
|
||||
|
@ -73,7 +134,6 @@ def prompt_for_config(
|
|||
|
||||
for field_name, field in config_type.__fields__.items():
|
||||
field_type = field.annotation
|
||||
|
||||
existing_value = (
|
||||
getattr(existing_config, field_name) if existing_config else None
|
||||
)
|
||||
|
@ -107,50 +167,13 @@ def prompt_for_config(
|
|||
)
|
||||
continue
|
||||
|
||||
# Check if the field is a discriminated union
|
||||
if get_origin(field_type) is Annotated:
|
||||
inner_type = get_args(field_type)[0]
|
||||
if get_origin(inner_type) is Union:
|
||||
discriminator = field.field_info.discriminator
|
||||
if discriminator:
|
||||
union_types = get_args(inner_type)
|
||||
# Find the discriminator field in each union type
|
||||
type_map = {}
|
||||
for t in union_types:
|
||||
disc_field = t.__fields__[discriminator]
|
||||
literal_values = get_literal_values(disc_field)
|
||||
if literal_values:
|
||||
for value in literal_values:
|
||||
type_map[value] = t
|
||||
if is_discriminated_union(field):
|
||||
config_data[field_name] = prompt_for_discriminated_union(
|
||||
field_name, field, existing_value
|
||||
)
|
||||
continue
|
||||
|
||||
while True:
|
||||
discriminator_value = input(
|
||||
f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
|
||||
)
|
||||
if discriminator_value in type_map:
|
||||
chosen_type = type_map[discriminator_value]
|
||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator)
|
||||
!= discriminator_value
|
||||
):
|
||||
existing_value = None
|
||||
|
||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||
config_data[field_name] = sub_config
|
||||
# Set the discriminator field in the sub-config
|
||||
setattr(sub_config, discriminator, discriminator_value)
|
||||
break
|
||||
else:
|
||||
print(f"Invalid {discriminator}. Please try again.")
|
||||
continue
|
||||
|
||||
if (
|
||||
is_optional(field_type)
|
||||
and inspect.isclass(get_non_none_type(field_type))
|
||||
and issubclass(get_non_none_type(field_type), BaseModel)
|
||||
):
|
||||
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
|
||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||
if input(prompt).lower() == "n":
|
||||
config_data[field_name] = None
|
||||
|
@ -158,11 +181,20 @@ def prompt_for_config(
|
|||
nested_type = get_non_none_type(field_type)
|
||||
print(f"Entering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||
elif (
|
||||
inspect.isclass(field_type)
|
||||
and issubclass(field_type, BaseModel)
|
||||
and len(field_type.__fields__) > 0
|
||||
elif is_optional(field_type) and is_discriminated_union(
|
||||
get_non_none_type(field_type)
|
||||
):
|
||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||
if input(prompt).lower() == "n":
|
||||
config_data[field_name] = None
|
||||
continue
|
||||
nested_type = get_non_none_type(field_type)
|
||||
config_data[field_name] = prompt_for_discriminated_union(
|
||||
field_name,
|
||||
nested_type,
|
||||
existing_value,
|
||||
)
|
||||
elif can_recurse(field_type):
|
||||
print(f"\nEntering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(
|
||||
field_type,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue