forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? - Closes #1142 - Root cause is due to having `Union[str, AgenToolGroupWithArgs]` ## Test Plan - Test with script described in issue. - Print out final converted pydantic object <img width="1470" alt="image" src="https://github.com/user-attachments/assets/15dc9cd0-f37a-4b91-905f-3fe4f59a08c6" /> [//]: # (## Documentation)
		
			
				
	
	
		
			226 lines
		
	
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			226 lines
		
	
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| """
 | |
| Type-safe data interchange for Python data classes.
 | |
| 
 | |
| :see: https://github.com/hunyadi/strong_typing
 | |
| """
 | |
| 
 | |
| import dataclasses
 | |
| import sys
 | |
| from dataclasses import is_dataclass
 | |
| from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload
 | |
| 
 | |
| if sys.version_info >= (3, 9):
 | |
|     from typing import Annotated as Annotated
 | |
| else:
 | |
|     from typing_extensions import Annotated as Annotated
 | |
| 
 | |
| if sys.version_info >= (3, 10):
 | |
|     from typing import TypeAlias as TypeAlias
 | |
| else:
 | |
|     from typing_extensions import TypeAlias as TypeAlias
 | |
| 
 | |
| if sys.version_info >= (3, 11):
 | |
|     from typing import dataclass_transform as dataclass_transform
 | |
| else:
 | |
|     from typing_extensions import dataclass_transform as dataclass_transform
 | |
| 
 | |
| T = TypeVar("T")
 | |
| 
 | |
| 
 | |
| def _compact_dataclass_repr(obj: object) -> str:
 | |
|     """
 | |
|     Compact data-class representation where positional arguments are used instead of keyword arguments.
 | |
| 
 | |
|     :param obj: A data-class object.
 | |
|     :returns: A string that matches the pattern `Class(arg1, arg2, ...)`.
 | |
|     """
 | |
| 
 | |
|     if is_dataclass(obj):
 | |
|         arglist = ", ".join(repr(getattr(obj, field.name)) for field in dataclasses.fields(obj))
 | |
|         return f"{obj.__class__.__name__}({arglist})"
 | |
|     else:
 | |
|         return obj.__class__.__name__
 | |
| 
 | |
| 
 | |
| class CompactDataClass:
 | |
|     "A data class whose repr() uses positional rather than keyword arguments."
 | |
| 
 | |
|     def __repr__(self) -> str:
 | |
|         return _compact_dataclass_repr(self)
 | |
| 
 | |
| 
 | |
| @overload
 | |
| def typeannotation(cls: Type[T], /) -> Type[T]: ...
 | |
| 
 | |
| 
 | |
| @overload
 | |
| def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ...
 | |
| 
 | |
| 
 | |
| @dataclass_transform(eq_default=True, order_default=False)
 | |
| def typeannotation(
 | |
|     cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False
 | |
| ) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
 | |
|     """
 | |
|     Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
 | |
| 
 | |
|     :param cls: The data-class type to transform into a type annotation.
 | |
|     :param eq: Whether to generate functions to support equality comparison.
 | |
|     :param order: Whether to generate functions to support ordering.
 | |
|     :returns: A data-class type, or a wrapper for data-class types.
 | |
|     """
 | |
| 
 | |
|     def wrap(cls: Type[T]) -> Type[T]:
 | |
|         cls.__repr__ = _compact_dataclass_repr
 | |
|         if not dataclasses.is_dataclass(cls):
 | |
|             cls = dataclasses.dataclass(  # type: ignore[call-overload]
 | |
|                 cls,
 | |
|                 init=True,
 | |
|                 repr=False,
 | |
|                 eq=eq,
 | |
|                 order=order,
 | |
|                 unsafe_hash=False,
 | |
|                 frozen=True,
 | |
|             )
 | |
|         return cls
 | |
| 
 | |
|     # see if decorator is used as @typeannotation or @typeannotation()
 | |
|     if cls is None:
 | |
|         # called with parentheses
 | |
|         return wrap
 | |
|     else:
 | |
|         # called without parentheses
 | |
|         return wrap(cls)
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class Alias:
 | |
|     "Alternative name of a property, typically used in JSON serialization."
 | |
| 
 | |
|     name: str
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class Signed:
 | |
|     "Signedness of an integer type."
 | |
| 
 | |
|     is_signed: bool
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class Storage:
 | |
|     "Number of bytes the binary representation of an integer type takes, e.g. 4 bytes for an int32."
 | |
| 
 | |
|     bytes: int
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class IntegerRange:
 | |
|     "Minimum and maximum value of an integer. The range is inclusive."
 | |
| 
 | |
|     minimum: int
 | |
|     maximum: int
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class Precision:
 | |
|     "Precision of a floating-point value."
 | |
| 
 | |
|     significant_digits: int
 | |
|     decimal_digits: int = 0
 | |
| 
 | |
|     @property
 | |
|     def integer_digits(self) -> int:
 | |
|         return self.significant_digits - self.decimal_digits
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class TimePrecision:
 | |
|     """
 | |
|     Precision of a timestamp or time interval.
 | |
| 
 | |
|     :param decimal_digits: Number of fractional digits retained in the sub-seconds field for a timestamp.
 | |
|     """
 | |
| 
 | |
|     decimal_digits: int = 0
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class Length:
 | |
|     "Exact length of a string."
 | |
| 
 | |
|     value: int
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class MinLength:
 | |
|     "Minimum length of a string."
 | |
| 
 | |
|     value: int
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class MaxLength:
 | |
|     "Maximum length of a string."
 | |
| 
 | |
|     value: int
 | |
| 
 | |
| 
 | |
| @typeannotation
 | |
| class SpecialConversion:
 | |
|     "Indicates that the annotated type is subject to custom conversion rules."
 | |
| 
 | |
| 
 | |
| int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
 | |
| int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
 | |
| int32: TypeAlias = Annotated[
 | |
|     int,
 | |
|     Signed(True),
 | |
|     Storage(4),
 | |
|     IntegerRange(-2147483648, 2147483647),
 | |
| ]
 | |
| int64: TypeAlias = Annotated[
 | |
|     int,
 | |
|     Signed(True),
 | |
|     Storage(8),
 | |
|     IntegerRange(-9223372036854775808, 9223372036854775807),
 | |
| ]
 | |
| 
 | |
| uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
 | |
| uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
 | |
| uint32: TypeAlias = Annotated[
 | |
|     int,
 | |
|     Signed(False),
 | |
|     Storage(4),
 | |
|     IntegerRange(0, 4294967295),
 | |
| ]
 | |
| uint64: TypeAlias = Annotated[
 | |
|     int,
 | |
|     Signed(False),
 | |
|     Storage(8),
 | |
|     IntegerRange(0, 18446744073709551615),
 | |
| ]
 | |
| 
 | |
| float32: TypeAlias = Annotated[float, Storage(4)]
 | |
| float64: TypeAlias = Annotated[float, Storage(8)]
 | |
| 
 | |
| # maps globals of type Annotated[T, ...] defined in this module to their string names
 | |
| _auxiliary_types: Dict[object, str] = {}
 | |
| module = sys.modules[__name__]
 | |
| for var in dir(module):
 | |
|     typ = getattr(module, var)
 | |
|     if getattr(typ, "__metadata__", None) is not None:
 | |
|         # type is Annotated[T, ...]
 | |
|         _auxiliary_types[typ] = var
 | |
| 
 | |
| 
 | |
| def get_auxiliary_format(data_type: object) -> Optional[str]:
 | |
|     "Returns the JSON format string corresponding to an auxiliary type."
 | |
| 
 | |
|     return _auxiliary_types.get(data_type)
 |