mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 03:14:19 +00:00
Handle Annotated types more correctly
This commit is contained in:
parent
53ab18d6bb
commit
73b71d9689
1 changed files with 80 additions and 48 deletions
|
@ -11,6 +11,7 @@ from enum import Enum
|
||||||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
from pydantic_core import PydanticUndefinedType
|
from pydantic_core import PydanticUndefinedType
|
||||||
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
@ -26,6 +27,12 @@ def is_list_of_primitives(field_type):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def can_recurse(typ):
|
||||||
|
return (
|
||||||
|
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_literal_values(field):
|
def get_literal_values(field):
|
||||||
"""Extract literal values from a field if it's a Literal type."""
|
"""Extract literal values from a field if it's a Literal type."""
|
||||||
if get_origin(field.annotation) is Literal:
|
if get_origin(field.annotation) is Literal:
|
||||||
|
@ -52,6 +59,60 @@ def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def is_discriminated_union(typ) -> bool:
|
||||||
|
if isinstance(typ, FieldInfo):
|
||||||
|
return typ.discriminator
|
||||||
|
else:
|
||||||
|
if not (get_origin(typ) is Annotated):
|
||||||
|
return False
|
||||||
|
args = get_args(typ)
|
||||||
|
return len(args) >= 2 and args[1].discriminator
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_for_discriminated_union(
|
||||||
|
field_name,
|
||||||
|
typ,
|
||||||
|
existing_value,
|
||||||
|
):
|
||||||
|
if isinstance(typ, FieldInfo):
|
||||||
|
inner_type = typ.annotation
|
||||||
|
discriminator = typ.discriminator
|
||||||
|
else:
|
||||||
|
args = get_args(typ)
|
||||||
|
inner_type = args[0]
|
||||||
|
discriminator = args[1].discriminator
|
||||||
|
|
||||||
|
union_types = get_args(inner_type)
|
||||||
|
# Find the discriminator field in each union type
|
||||||
|
type_map = {}
|
||||||
|
for t in union_types:
|
||||||
|
disc_field = t.__fields__[discriminator]
|
||||||
|
literal_values = get_literal_values(disc_field)
|
||||||
|
if literal_values:
|
||||||
|
for value in literal_values:
|
||||||
|
type_map[value] = t
|
||||||
|
|
||||||
|
while True:
|
||||||
|
discriminator_value = input(
|
||||||
|
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
|
||||||
|
)
|
||||||
|
if discriminator_value in type_map:
|
||||||
|
chosen_type = type_map[discriminator_value]
|
||||||
|
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||||
|
|
||||||
|
if existing_value and (
|
||||||
|
getattr(existing_value, discriminator) != discriminator_value
|
||||||
|
):
|
||||||
|
existing_value = None
|
||||||
|
|
||||||
|
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||||
|
# Set the discriminator field in the sub-config
|
||||||
|
setattr(sub_config, discriminator, discriminator_value)
|
||||||
|
return sub_config
|
||||||
|
else:
|
||||||
|
print(f"Invalid {discriminator}. Please try again.")
|
||||||
|
|
||||||
|
|
||||||
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
# This is somewhat elaborate, but does not purport to be comprehensive in any way.
|
||||||
# We should add handling for the most common cases to tide us over.
|
# We should add handling for the most common cases to tide us over.
|
||||||
#
|
#
|
||||||
|
@ -73,7 +134,6 @@ def prompt_for_config(
|
||||||
|
|
||||||
for field_name, field in config_type.__fields__.items():
|
for field_name, field in config_type.__fields__.items():
|
||||||
field_type = field.annotation
|
field_type = field.annotation
|
||||||
|
|
||||||
existing_value = (
|
existing_value = (
|
||||||
getattr(existing_config, field_name) if existing_config else None
|
getattr(existing_config, field_name) if existing_config else None
|
||||||
)
|
)
|
||||||
|
@ -107,50 +167,13 @@ def prompt_for_config(
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if the field is a discriminated union
|
if is_discriminated_union(field):
|
||||||
if get_origin(field_type) is Annotated:
|
config_data[field_name] = prompt_for_discriminated_union(
|
||||||
inner_type = get_args(field_type)[0]
|
field_name, field, existing_value
|
||||||
if get_origin(inner_type) is Union:
|
)
|
||||||
discriminator = field.field_info.discriminator
|
continue
|
||||||
if discriminator:
|
|
||||||
union_types = get_args(inner_type)
|
|
||||||
# Find the discriminator field in each union type
|
|
||||||
type_map = {}
|
|
||||||
for t in union_types:
|
|
||||||
disc_field = t.__fields__[discriminator]
|
|
||||||
literal_values = get_literal_values(disc_field)
|
|
||||||
if literal_values:
|
|
||||||
for value in literal_values:
|
|
||||||
type_map[value] = t
|
|
||||||
|
|
||||||
while True:
|
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
|
||||||
discriminator_value = input(
|
|
||||||
f"Enter the {discriminator} (options: {', '.join(type_map.keys())}): "
|
|
||||||
)
|
|
||||||
if discriminator_value in type_map:
|
|
||||||
chosen_type = type_map[discriminator_value]
|
|
||||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
|
||||||
|
|
||||||
if existing_value and (
|
|
||||||
getattr(existing_value, discriminator)
|
|
||||||
!= discriminator_value
|
|
||||||
):
|
|
||||||
existing_value = None
|
|
||||||
|
|
||||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
|
||||||
config_data[field_name] = sub_config
|
|
||||||
# Set the discriminator field in the sub-config
|
|
||||||
setattr(sub_config, discriminator, discriminator_value)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"Invalid {discriminator}. Please try again.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (
|
|
||||||
is_optional(field_type)
|
|
||||||
and inspect.isclass(get_non_none_type(field_type))
|
|
||||||
and issubclass(get_non_none_type(field_type), BaseModel)
|
|
||||||
):
|
|
||||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||||
if input(prompt).lower() == "n":
|
if input(prompt).lower() == "n":
|
||||||
config_data[field_name] = None
|
config_data[field_name] = None
|
||||||
|
@ -158,11 +181,20 @@ def prompt_for_config(
|
||||||
nested_type = get_non_none_type(field_type)
|
nested_type = get_non_none_type(field_type)
|
||||||
print(f"Entering sub-configuration for {field_name}:")
|
print(f"Entering sub-configuration for {field_name}:")
|
||||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||||
elif (
|
elif is_optional(field_type) and is_discriminated_union(
|
||||||
inspect.isclass(field_type)
|
get_non_none_type(field_type)
|
||||||
and issubclass(field_type, BaseModel)
|
|
||||||
and len(field_type.__fields__) > 0
|
|
||||||
):
|
):
|
||||||
|
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||||
|
if input(prompt).lower() == "n":
|
||||||
|
config_data[field_name] = None
|
||||||
|
continue
|
||||||
|
nested_type = get_non_none_type(field_type)
|
||||||
|
config_data[field_name] = prompt_for_discriminated_union(
|
||||||
|
field_name,
|
||||||
|
nested_type,
|
||||||
|
existing_value,
|
||||||
|
)
|
||||||
|
elif can_recurse(field_type):
|
||||||
print(f"\nEntering sub-configuration for {field_name}:")
|
print(f"\nEntering sub-configuration for {field_name}:")
|
||||||
config_data[field_name] = prompt_for_config(
|
config_data[field_name] = prompt_for_config(
|
||||||
field_type,
|
field_type,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue