From e8cb9e0adba6485c438bb7cb1e311ac80a90a06c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 18 Feb 2025 16:07:54 -0800 Subject: [PATCH] fix: direct client pydantic type casting (#1145) # What does this PR do? - Closes #1142 - Root cause is due to having `Union[str, AgenToolGroupWithArgs]` ## Test Plan - Test with script described in issue. - Print out final converted pydantic object image [//]: # (## Documentation) --- llama_stack/distribution/library_client.py | 25 ++++++++++++++++------ llama_stack/strong_typing/auxiliary.py | 2 +- llama_stack/strong_typing/classdef.py | 2 +- llama_stack/strong_typing/deserializer.py | 2 +- llama_stack/strong_typing/inspection.py | 6 ++++-- llama_stack/strong_typing/serializer.py | 2 +- 6 files changed, 26 insertions(+), 13 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index a7ef753b9..a40651551 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -13,7 +13,7 @@ import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Optional, TypeVar, get_args, get_origin +from typing import Any, Optional, TypeVar, Union, get_args, get_origin import httpx import yaml @@ -81,12 +81,13 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: return value origin = get_origin(annotation) + if origin is list: item_type = get_args(annotation)[0] try: return [convert_to_pydantic(item_type, item) for item in value] except Exception: - print(f"Error converting list {value}") + print(f"Error converting list {value} into {item_type}") return value elif origin is dict: @@ -94,17 +95,26 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: try: return {k: convert_to_pydantic(val_type, v) for k, v in value.items()} except Exception: - print(f"Error converting dict {value}") + print(f"Error converting dict {value} into {val_type}") return value try: # Handle Pydantic models and discriminated unions return TypeAdapter(annotation).validate_python(value) + except Exception as e: - cprint( - f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", - "yellow", - ) + # TODO: this is workardound for having Union[str, AgentToolGroup] in API schema. + # We should get rid of any non-discriminated unions in the API schema. + if origin is Union: + for union_type in get_args(annotation): + try: + return convert_to_pydantic(union_type, value) + except Exception: + continue + cprint( + f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", + "yellow", + ) return value @@ -421,4 +431,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if param_name in body: value = body.get(param_name) converted_body[param_name] = convert_to_pydantic(param.annotation, value) + return converted_body diff --git a/llama_stack/strong_typing/auxiliary.py b/llama_stack/strong_typing/auxiliary.py index fd183da18..cf19d6083 100644 --- a/llama_stack/strong_typing/auxiliary.py +++ b/llama_stack/strong_typing/auxiliary.py @@ -77,7 +77,7 @@ def typeannotation( """ def wrap(cls: Type[T]) -> Type[T]: - setattr(cls, "__repr__", _compact_dataclass_repr) + cls.__repr__ = _compact_dataclass_repr if not dataclasses.is_dataclass(cls): cls = dataclasses.dataclass( # type: ignore[call-overload] cls, diff --git a/llama_stack/strong_typing/classdef.py b/llama_stack/strong_typing/classdef.py index d2d8688e4..5ead886d4 100644 --- a/llama_stack/strong_typing/classdef.py +++ b/llama_stack/strong_typing/classdef.py @@ -203,7 +203,7 @@ def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str) if type_def.default is not dataclasses.MISSING: raise TypeError("disallowed: `default` for top-level type definitions") - setattr(type_def.type, "__module__", module.__name__) + type_def.type.__module__ = module.__name__ setattr(module, type_name, type_def.type) return node_to_typedef(module, class_name, top_node).type diff --git a/llama_stack/strong_typing/deserializer.py b/llama_stack/strong_typing/deserializer.py index 4c4ee9d89..fc0f40f83 100644 --- a/llama_stack/strong_typing/deserializer.py +++ b/llama_stack/strong_typing/deserializer.py @@ -325,7 +325,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]): f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}" ) - return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data)) + return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data, strict=False)) class UnionDeserializer(Deserializer): diff --git a/llama_stack/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py index 69bc15597..8bc313021 100644 --- a/llama_stack/strong_typing/inspection.py +++ b/llama_stack/strong_typing/inspection.py @@ -263,8 +263,8 @@ def extend_enum( enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore # assign the newly created type to the same module where the extending class is defined - setattr(enum_class, "__module__", extend.__module__) - setattr(enum_class, "__doc__", extend.__doc__) + enum_class.__module__ = extend.__module__ + enum_class.__doc__ = extend.__doc__ setattr(sys.modules[extend.__module__], extend.__name__, enum_class) return enum.unique(enum_class) @@ -874,6 +874,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool: for tuple_item_type, item in zip( (tuple_item_type for tuple_item_type in typing.get_args(typ)), (item for item in obj), + strict=False, ) ) elif origin_type is Union: @@ -954,6 +955,7 @@ class RecursiveChecker: for tuple_item_type, item in zip( (tuple_item_type for tuple_item_type in typing.get_args(typ)), (item for item in obj), + strict=False, ) ) elif origin_type is Union: diff --git a/llama_stack/strong_typing/serializer.py b/llama_stack/strong_typing/serializer.py index 5e93e4c4d..4ca4a4119 100644 --- a/llama_stack/strong_typing/serializer.py +++ b/llama_stack/strong_typing/serializer.py @@ -216,7 +216,7 @@ class TypedTupleSerializer(Serializer[tuple]): self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types) def generate(self, obj: tuple) -> List[JsonType]: - return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)] + return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)] class CustomSerializer(Serializer):