fix: direct client pydantic type casting (#1145)

# What does this PR do?
- Closes #1142 
- Root cause is due to having `Union[str, AgenToolGroupWithArgs]`

## Test Plan
- Test with script described in issue. 

- Print out final converted pydantic object
<img width="1470" alt="image"
src="https://github.com/user-attachments/assets/15dc9cd0-f37a-4b91-905f-3fe4f59a08c6"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-02-18 16:07:54 -08:00 committed by GitHub
parent 8585b95a28
commit e8cb9e0adb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 26 additions and 13 deletions

View file

@ -13,7 +13,7 @@ import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Optional, TypeVar, get_args, get_origin from typing import Any, Optional, TypeVar, Union, get_args, get_origin
import httpx import httpx
import yaml import yaml
@ -81,12 +81,13 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
return value return value
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is list: if origin is list:
item_type = get_args(annotation)[0] item_type = get_args(annotation)[0]
try: try:
return [convert_to_pydantic(item_type, item) for item in value] return [convert_to_pydantic(item_type, item) for item in value]
except Exception: except Exception:
print(f"Error converting list {value}") print(f"Error converting list {value} into {item_type}")
return value return value
elif origin is dict: elif origin is dict:
@ -94,17 +95,26 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
try: try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()} return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception: except Exception:
print(f"Error converting dict {value}") print(f"Error converting dict {value} into {val_type}")
return value return value
try: try:
# Handle Pydantic models and discriminated unions # Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value) return TypeAdapter(annotation).validate_python(value)
except Exception as e: except Exception as e:
cprint( # TODO: this is workardound for having Union[str, AgentToolGroup] in API schema.
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", # We should get rid of any non-discriminated unions in the API schema.
"yellow", if origin is Union:
) for union_type in get_args(annotation):
try:
return convert_to_pydantic(union_type, value)
except Exception:
continue
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value return value
@ -421,4 +431,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if param_name in body: if param_name in body:
value = body.get(param_name) value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic(param.annotation, value) converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body return converted_body

View file

@ -77,7 +77,7 @@ def typeannotation(
""" """
def wrap(cls: Type[T]) -> Type[T]: def wrap(cls: Type[T]) -> Type[T]:
setattr(cls, "__repr__", _compact_dataclass_repr) cls.__repr__ = _compact_dataclass_repr
if not dataclasses.is_dataclass(cls): if not dataclasses.is_dataclass(cls):
cls = dataclasses.dataclass( # type: ignore[call-overload] cls = dataclasses.dataclass( # type: ignore[call-overload]
cls, cls,

View file

@ -203,7 +203,7 @@ def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str)
if type_def.default is not dataclasses.MISSING: if type_def.default is not dataclasses.MISSING:
raise TypeError("disallowed: `default` for top-level type definitions") raise TypeError("disallowed: `default` for top-level type definitions")
setattr(type_def.type, "__module__", module.__name__) type_def.type.__module__ = module.__name__
setattr(module, type_name, type_def.type) setattr(module, type_name, type_def.type)
return node_to_typedef(module, class_name, top_node).type return node_to_typedef(module, class_name, top_node).type

View file

@ -325,7 +325,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}" f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}"
) )
return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data)) return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data, strict=False))
class UnionDeserializer(Deserializer): class UnionDeserializer(Deserializer):

View file

@ -263,8 +263,8 @@ def extend_enum(
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore
# assign the newly created type to the same module where the extending class is defined # assign the newly created type to the same module where the extending class is defined
setattr(enum_class, "__module__", extend.__module__) enum_class.__module__ = extend.__module__
setattr(enum_class, "__doc__", extend.__doc__) enum_class.__doc__ = extend.__doc__
setattr(sys.modules[extend.__module__], extend.__name__, enum_class) setattr(sys.modules[extend.__module__], extend.__name__, enum_class)
return enum.unique(enum_class) return enum.unique(enum_class)
@ -874,6 +874,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
for tuple_item_type, item in zip( for tuple_item_type, item in zip(
(tuple_item_type for tuple_item_type in typing.get_args(typ)), (tuple_item_type for tuple_item_type in typing.get_args(typ)),
(item for item in obj), (item for item in obj),
strict=False,
) )
) )
elif origin_type is Union: elif origin_type is Union:
@ -954,6 +955,7 @@ class RecursiveChecker:
for tuple_item_type, item in zip( for tuple_item_type, item in zip(
(tuple_item_type for tuple_item_type in typing.get_args(typ)), (tuple_item_type for tuple_item_type in typing.get_args(typ)),
(item for item in obj), (item for item in obj),
strict=False,
) )
) )
elif origin_type is Union: elif origin_type is Union:

View file

@ -216,7 +216,7 @@ class TypedTupleSerializer(Serializer[tuple]):
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types) self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
def generate(self, obj: tuple) -> List[JsonType]: def generate(self, obj: tuple) -> List[JsonType]:
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)] return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]
class CustomSerializer(Serializer): class CustomSerializer(Serializer):