fix: direct client pydantic type casting (#1145)

# What does this PR do?
- Closes #1142 
- Root cause is due to having `Union[str, AgenToolGroupWithArgs]`

## Test Plan
- Test with script described in issue. 

- Print out final converted pydantic object
<img width="1470" alt="image"
src="https://github.com/user-attachments/assets/15dc9cd0-f37a-4b91-905f-3fe4f59a08c6"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-02-18 16:07:54 -08:00 committed by GitHub
parent 8585b95a28
commit e8cb9e0adb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 26 additions and 13 deletions

View file

@ -77,7 +77,7 @@ def typeannotation(
"""
def wrap(cls: Type[T]) -> Type[T]:
setattr(cls, "__repr__", _compact_dataclass_repr)
cls.__repr__ = _compact_dataclass_repr
if not dataclasses.is_dataclass(cls):
cls = dataclasses.dataclass( # type: ignore[call-overload]
cls,

View file

@ -203,7 +203,7 @@ def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str)
if type_def.default is not dataclasses.MISSING:
raise TypeError("disallowed: `default` for top-level type definitions")
setattr(type_def.type, "__module__", module.__name__)
type_def.type.__module__ = module.__name__
setattr(module, type_name, type_def.type)
return node_to_typedef(module, class_name, top_node).type

View file

@ -325,7 +325,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}"
)
return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data))
return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data, strict=False))
class UnionDeserializer(Deserializer):

View file

@ -263,8 +263,8 @@ def extend_enum(
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore
# assign the newly created type to the same module where the extending class is defined
setattr(enum_class, "__module__", extend.__module__)
setattr(enum_class, "__doc__", extend.__doc__)
enum_class.__module__ = extend.__module__
enum_class.__doc__ = extend.__doc__
setattr(sys.modules[extend.__module__], extend.__name__, enum_class)
return enum.unique(enum_class)
@ -874,6 +874,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
for tuple_item_type, item in zip(
(tuple_item_type for tuple_item_type in typing.get_args(typ)),
(item for item in obj),
strict=False,
)
)
elif origin_type is Union:
@ -954,6 +955,7 @@ class RecursiveChecker:
for tuple_item_type, item in zip(
(tuple_item_type for tuple_item_type in typing.get_args(typ)),
(item for item in obj),
strict=False,
)
)
elif origin_type is Union:

View file

@ -216,7 +216,7 @@ class TypedTupleSerializer(Serializer[tuple]):
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
def generate(self, obj: tuple) -> List[JsonType]:
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)]
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]
class CustomSerializer(Serializer):