forked from phoenix-oss/llama-stack-mirror
chore: move all Llama Stack types from llama-models to llama-stack (#1098)
llama-models should have extremely minimal cruft. Its sole purpose should be didactic -- show the simplest implementation of the llama models and document the prompt formats, etc. This PR is the complement to https://github.com/meta-llama/llama-models/pull/279 ## Test Plan Ensure all `llama` CLI `model` sub-commands work: ```bash llama model list llama model download --model-id ... llama model prompt-format -m ... ``` Ran tests: ```bash cd tests/client-sdk LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/ LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/ LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/ ``` Create a fresh venv `uv venv && source .venv/bin/activate` and run `llama stack build --template fireworks --image-type venv` followed by `llama stack run together --image-type venv` <-- the server runs Also checked that the OpenAPI generator can run and there is no change in the generated files as a result. ```bash cd docs/openapi_generator sh run_openapi_generator.sh ```
This commit is contained in:
parent
c0ee512980
commit
314ee09ae3
138 changed files with 8491 additions and 465 deletions
19
llama_stack/strong_typing/__init__.py
Normal file
19
llama_stack/strong_typing/__init__.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
Provides auxiliary services for working with Python type annotations, converting typed data to and from JSON,
|
||||
and generating a JSON schema for a complex type.
|
||||
"""
|
||||
|
||||
__version__ = "0.3.4"
|
||||
__author__ = "Levente Hunyadi"
|
||||
__copyright__ = "Copyright 2021-2024, Levente Hunyadi"
|
||||
__license__ = "MIT"
|
||||
__maintainer__ = "Levente Hunyadi"
|
||||
__status__ = "Production"
|
226
llama_stack/strong_typing/auxiliary.py
Normal file
226
llama_stack/strong_typing/auxiliary.py
Normal file
|
@ -0,0 +1,226 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated as Annotated
|
||||
else:
|
||||
from typing_extensions import Annotated as Annotated
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias as TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias as TypeAlias
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import dataclass_transform as dataclass_transform
|
||||
else:
|
||||
from typing_extensions import dataclass_transform as dataclass_transform
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _compact_dataclass_repr(obj: object) -> str:
|
||||
"""
|
||||
Compact data-class representation where positional arguments are used instead of keyword arguments.
|
||||
|
||||
:param obj: A data-class object.
|
||||
:returns: A string that matches the pattern `Class(arg1, arg2, ...)`.
|
||||
"""
|
||||
|
||||
if is_dataclass(obj):
|
||||
arglist = ", ".join(repr(getattr(obj, field.name)) for field in dataclasses.fields(obj))
|
||||
return f"{obj.__class__.__name__}({arglist})"
|
||||
else:
|
||||
return obj.__class__.__name__
|
||||
|
||||
|
||||
class CompactDataClass:
|
||||
"A data class whose repr() uses positional rather than keyword arguments."
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return _compact_dataclass_repr(self)
|
||||
|
||||
|
||||
@overload
|
||||
def typeannotation(cls: Type[T], /) -> Type[T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ...
|
||||
|
||||
|
||||
@dataclass_transform(eq_default=True, order_default=False)
|
||||
def typeannotation(
|
||||
cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False
|
||||
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
|
||||
"""
|
||||
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
|
||||
|
||||
:param cls: The data-class type to transform into a type annotation.
|
||||
:param eq: Whether to generate functions to support equality comparison.
|
||||
:param order: Whether to generate functions to support ordering.
|
||||
:returns: A data-class type, or a wrapper for data-class types.
|
||||
"""
|
||||
|
||||
def wrap(cls: Type[T]) -> Type[T]:
|
||||
setattr(cls, "__repr__", _compact_dataclass_repr)
|
||||
if not dataclasses.is_dataclass(cls):
|
||||
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
||||
cls,
|
||||
init=True,
|
||||
repr=False,
|
||||
eq=eq,
|
||||
order=order,
|
||||
unsafe_hash=False,
|
||||
frozen=True,
|
||||
)
|
||||
return cls
|
||||
|
||||
# see if decorator is used as @typeannotation or @typeannotation()
|
||||
if cls is None:
|
||||
# called with parentheses
|
||||
return wrap
|
||||
else:
|
||||
# called without parentheses
|
||||
return wrap(cls)
|
||||
|
||||
|
||||
@typeannotation
|
||||
class Alias:
|
||||
"Alternative name of a property, typically used in JSON serialization."
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
@typeannotation
|
||||
class Signed:
|
||||
"Signedness of an integer type."
|
||||
|
||||
is_signed: bool
|
||||
|
||||
|
||||
@typeannotation
|
||||
class Storage:
|
||||
"Number of bytes the binary representation of an integer type takes, e.g. 4 bytes for an int32."
|
||||
|
||||
bytes: int
|
||||
|
||||
|
||||
@typeannotation
|
||||
class IntegerRange:
|
||||
"Minimum and maximum value of an integer. The range is inclusive."
|
||||
|
||||
minimum: int
|
||||
maximum: int
|
||||
|
||||
|
||||
@typeannotation
|
||||
class Precision:
|
||||
"Precision of a floating-point value."
|
||||
|
||||
significant_digits: int
|
||||
decimal_digits: int = 0
|
||||
|
||||
@property
|
||||
def integer_digits(self) -> int:
|
||||
return self.significant_digits - self.decimal_digits
|
||||
|
||||
|
||||
@typeannotation
|
||||
class TimePrecision:
|
||||
"""
|
||||
Precision of a timestamp or time interval.
|
||||
|
||||
:param decimal_digits: Number of fractional digits retained in the sub-seconds field for a timestamp.
|
||||
"""
|
||||
|
||||
decimal_digits: int = 0
|
||||
|
||||
|
||||
@typeannotation
|
||||
class Length:
|
||||
"Exact length of a string."
|
||||
|
||||
value: int
|
||||
|
||||
|
||||
@typeannotation
|
||||
class MinLength:
|
||||
"Minimum length of a string."
|
||||
|
||||
value: int
|
||||
|
||||
|
||||
@typeannotation
|
||||
class MaxLength:
|
||||
"Maximum length of a string."
|
||||
|
||||
value: int
|
||||
|
||||
|
||||
@typeannotation
|
||||
class SpecialConversion:
|
||||
"Indicates that the annotated type is subject to custom conversion rules."
|
||||
|
||||
|
||||
int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
|
||||
int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
|
||||
int32: TypeAlias = Annotated[
|
||||
int,
|
||||
Signed(True),
|
||||
Storage(4),
|
||||
IntegerRange(-2147483648, 2147483647),
|
||||
]
|
||||
int64: TypeAlias = Annotated[
|
||||
int,
|
||||
Signed(True),
|
||||
Storage(8),
|
||||
IntegerRange(-9223372036854775808, 9223372036854775807),
|
||||
]
|
||||
|
||||
uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
|
||||
uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
|
||||
uint32: TypeAlias = Annotated[
|
||||
int,
|
||||
Signed(False),
|
||||
Storage(4),
|
||||
IntegerRange(0, 4294967295),
|
||||
]
|
||||
uint64: TypeAlias = Annotated[
|
||||
int,
|
||||
Signed(False),
|
||||
Storage(8),
|
||||
IntegerRange(0, 18446744073709551615),
|
||||
]
|
||||
|
||||
float32: TypeAlias = Annotated[float, Storage(4)]
|
||||
float64: TypeAlias = Annotated[float, Storage(8)]
|
||||
|
||||
# maps globals of type Annotated[T, ...] defined in this module to their string names
|
||||
_auxiliary_types: Dict[object, str] = {}
|
||||
module = sys.modules[__name__]
|
||||
for var in dir(module):
|
||||
typ = getattr(module, var)
|
||||
if getattr(typ, "__metadata__", None) is not None:
|
||||
# type is Annotated[T, ...]
|
||||
_auxiliary_types[typ] = var
|
||||
|
||||
|
||||
def get_auxiliary_format(data_type: object) -> Optional[str]:
|
||||
"Returns the JSON format string corresponding to an auxiliary type."
|
||||
|
||||
return _auxiliary_types.get(data_type)
|
440
llama_stack/strong_typing/classdef.py
Normal file
440
llama_stack/strong_typing/classdef.py
Normal file
|
@ -0,0 +1,440 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import datetime
|
||||
import decimal
|
||||
import enum
|
||||
import ipaddress
|
||||
import math
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from .auxiliary import (
|
||||
Alias,
|
||||
Annotated,
|
||||
MaxLength,
|
||||
Precision,
|
||||
float32,
|
||||
float64,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
)
|
||||
from .core import JsonType, Schema
|
||||
from .docstring import Docstring, DocstringParam
|
||||
from .inspection import TypeLike
|
||||
from .serialization import json_to_object, object_to_json
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaNode:
|
||||
title: Optional[str]
|
||||
description: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaType(JsonSchemaNode):
|
||||
type: str
|
||||
format: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaBoolean(JsonSchemaType):
|
||||
type: Literal["boolean"]
|
||||
const: Optional[bool]
|
||||
default: Optional[bool]
|
||||
examples: Optional[List[bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaInteger(JsonSchemaType):
|
||||
type: Literal["integer"]
|
||||
const: Optional[int]
|
||||
default: Optional[int]
|
||||
examples: Optional[List[int]]
|
||||
enum: Optional[List[int]]
|
||||
minimum: Optional[int]
|
||||
maximum: Optional[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaNumber(JsonSchemaType):
|
||||
type: Literal["number"]
|
||||
const: Optional[float]
|
||||
default: Optional[float]
|
||||
examples: Optional[List[float]]
|
||||
minimum: Optional[float]
|
||||
maximum: Optional[float]
|
||||
exclusiveMinimum: Optional[float]
|
||||
exclusiveMaximum: Optional[float]
|
||||
multipleOf: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaString(JsonSchemaType):
|
||||
type: Literal["string"]
|
||||
const: Optional[str]
|
||||
default: Optional[str]
|
||||
examples: Optional[List[str]]
|
||||
enum: Optional[List[str]]
|
||||
minLength: Optional[int]
|
||||
maxLength: Optional[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaArray(JsonSchemaType):
|
||||
type: Literal["array"]
|
||||
items: "JsonSchemaAny"
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaObject(JsonSchemaType):
|
||||
type: Literal["object"]
|
||||
properties: Optional[Dict[str, "JsonSchemaAny"]]
|
||||
additionalProperties: Optional[bool]
|
||||
required: Optional[List[str]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaRef(JsonSchemaNode):
|
||||
ref: Annotated[str, Alias("$ref")]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaAllOf(JsonSchemaNode):
|
||||
allOf: List["JsonSchemaAny"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaAnyOf(JsonSchemaNode):
|
||||
anyOf: List["JsonSchemaAny"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Discriminator:
|
||||
propertyName: str
|
||||
mapping: Dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaOneOf(JsonSchemaNode):
|
||||
oneOf: List["JsonSchemaAny"]
|
||||
discriminator: Optional[Discriminator]
|
||||
|
||||
|
||||
JsonSchemaAny = Union[
|
||||
JsonSchemaRef,
|
||||
JsonSchemaBoolean,
|
||||
JsonSchemaInteger,
|
||||
JsonSchemaNumber,
|
||||
JsonSchemaString,
|
||||
JsonSchemaArray,
|
||||
JsonSchemaObject,
|
||||
JsonSchemaOneOf,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaTopLevelObject(JsonSchemaObject):
|
||||
schema: Annotated[str, Alias("$schema")]
|
||||
definitions: Optional[Dict[str, JsonSchemaAny]]
|
||||
|
||||
|
||||
def integer_range_to_type(min_value: float, max_value: float) -> type:
|
||||
if min_value >= -(2**15) and max_value < 2**15:
|
||||
return int16
|
||||
elif min_value >= -(2**31) and max_value < 2**31:
|
||||
return int32
|
||||
else:
|
||||
return int64
|
||||
|
||||
|
||||
def enum_safe_name(name: str) -> str:
|
||||
name = re.sub(r"\W", "_", name)
|
||||
is_dunder = name.startswith("__")
|
||||
is_sunder = name.startswith("_") and name.endswith("_")
|
||||
if is_dunder or is_sunder: # provide an alternative for dunder and sunder names
|
||||
name = f"v{name}"
|
||||
return name
|
||||
|
||||
|
||||
def enum_values_to_type(
|
||||
module: types.ModuleType,
|
||||
name: str,
|
||||
values: Dict[str, Any],
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> Type[enum.Enum]:
|
||||
enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore
|
||||
|
||||
# assign the newly created type to the same module where the defining class is
|
||||
enum_class.__module__ = module.__name__
|
||||
enum_class.__doc__ = str(Docstring(short_description=title, long_description=description))
|
||||
setattr(module, name, enum_class)
|
||||
|
||||
return enum.unique(enum_class)
|
||||
|
||||
|
||||
def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str) -> TypeLike:
|
||||
"""
|
||||
Creates a Python type from a JSON schema.
|
||||
|
||||
:param schema: The JSON schema that the types would correspond to.
|
||||
:param module: The module in which to create the new types.
|
||||
:param class_name: The name assigned to the top-level class.
|
||||
"""
|
||||
|
||||
top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema))
|
||||
if top_node.definitions is not None:
|
||||
for type_name, type_node in top_node.definitions.items():
|
||||
type_def = node_to_typedef(module, type_name, type_node)
|
||||
if type_def.default is not dataclasses.MISSING:
|
||||
raise TypeError("disallowed: `default` for top-level type definitions")
|
||||
|
||||
setattr(type_def.type, "__module__", module.__name__)
|
||||
setattr(module, type_name, type_def.type)
|
||||
|
||||
return node_to_typedef(module, class_name, top_node).type
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeDef:
|
||||
type: TypeLike
|
||||
default: Any = dataclasses.MISSING
|
||||
|
||||
|
||||
def json_to_value(target_type: TypeLike, data: JsonType) -> Any:
|
||||
if data is not None:
|
||||
return json_to_object(target_type, data)
|
||||
else:
|
||||
return dataclasses.MISSING
|
||||
|
||||
|
||||
def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode) -> TypeDef:
|
||||
if isinstance(node, JsonSchemaRef):
|
||||
match_obj = re.match(r"^#/definitions/(\w+)$", node.ref)
|
||||
if not match_obj:
|
||||
raise ValueError(f"invalid reference: {node.ref}")
|
||||
|
||||
type_name = match_obj.group(1)
|
||||
return TypeDef(getattr(module, type_name), dataclasses.MISSING)
|
||||
|
||||
elif isinstance(node, JsonSchemaBoolean):
|
||||
if node.const is not None:
|
||||
return TypeDef(Literal[node.const], dataclasses.MISSING)
|
||||
|
||||
default = json_to_value(bool, node.default)
|
||||
return TypeDef(bool, default)
|
||||
|
||||
elif isinstance(node, JsonSchemaInteger):
|
||||
if node.const is not None:
|
||||
return TypeDef(Literal[node.const], dataclasses.MISSING)
|
||||
|
||||
integer_type: TypeLike
|
||||
if node.format == "int16":
|
||||
integer_type = int16
|
||||
elif node.format == "int32":
|
||||
integer_type = int32
|
||||
elif node.format == "int64":
|
||||
integer_type = int64
|
||||
else:
|
||||
if node.enum is not None:
|
||||
integer_type = integer_range_to_type(min(node.enum), max(node.enum))
|
||||
elif node.minimum is not None and node.maximum is not None:
|
||||
integer_type = integer_range_to_type(node.minimum, node.maximum)
|
||||
else:
|
||||
integer_type = int
|
||||
|
||||
default = json_to_value(integer_type, node.default)
|
||||
return TypeDef(integer_type, default)
|
||||
|
||||
elif isinstance(node, JsonSchemaNumber):
|
||||
if node.const is not None:
|
||||
return TypeDef(Literal[node.const], dataclasses.MISSING)
|
||||
|
||||
number_type: TypeLike
|
||||
if node.format == "float32":
|
||||
number_type = float32
|
||||
elif node.format == "float64":
|
||||
number_type = float64
|
||||
else:
|
||||
if (
|
||||
node.exclusiveMinimum is not None
|
||||
and node.exclusiveMaximum is not None
|
||||
and node.exclusiveMinimum == -node.exclusiveMaximum
|
||||
):
|
||||
integer_digits = round(math.log10(node.exclusiveMaximum))
|
||||
else:
|
||||
integer_digits = None
|
||||
|
||||
if node.multipleOf is not None:
|
||||
decimal_digits = -round(math.log10(node.multipleOf))
|
||||
else:
|
||||
decimal_digits = None
|
||||
|
||||
if integer_digits is not None and decimal_digits is not None:
|
||||
number_type = Annotated[
|
||||
decimal.Decimal,
|
||||
Precision(integer_digits + decimal_digits, decimal_digits),
|
||||
]
|
||||
else:
|
||||
number_type = float
|
||||
|
||||
default = json_to_value(number_type, node.default)
|
||||
return TypeDef(number_type, default)
|
||||
|
||||
elif isinstance(node, JsonSchemaString):
|
||||
if node.const is not None:
|
||||
return TypeDef(Literal[node.const], dataclasses.MISSING)
|
||||
|
||||
string_type: TypeLike
|
||||
if node.format == "date-time":
|
||||
string_type = datetime.datetime
|
||||
elif node.format == "uuid":
|
||||
string_type = uuid.UUID
|
||||
elif node.format == "ipv4":
|
||||
string_type = ipaddress.IPv4Address
|
||||
elif node.format == "ipv6":
|
||||
string_type = ipaddress.IPv6Address
|
||||
|
||||
elif node.enum is not None:
|
||||
string_type = enum_values_to_type(
|
||||
module,
|
||||
context,
|
||||
{enum_safe_name(e): e for e in node.enum},
|
||||
title=node.title,
|
||||
description=node.description,
|
||||
)
|
||||
|
||||
elif node.maxLength is not None:
|
||||
string_type = Annotated[str, MaxLength(node.maxLength)]
|
||||
else:
|
||||
string_type = str
|
||||
|
||||
default = json_to_value(string_type, node.default)
|
||||
return TypeDef(string_type, default)
|
||||
|
||||
elif isinstance(node, JsonSchemaArray):
|
||||
type_def = node_to_typedef(module, context, node.items)
|
||||
if type_def.default is not dataclasses.MISSING:
|
||||
raise TypeError("disallowed: `default` for array element type")
|
||||
list_type = List[(type_def.type,)] # type: ignore
|
||||
return TypeDef(list_type, dataclasses.MISSING)
|
||||
|
||||
elif isinstance(node, JsonSchemaObject):
|
||||
if node.properties is None:
|
||||
return TypeDef(JsonType, dataclasses.MISSING)
|
||||
|
||||
if node.additionalProperties is None or node.additionalProperties is not False:
|
||||
raise TypeError("expected: `additionalProperties` equals `false`")
|
||||
|
||||
required = node.required if node.required is not None else []
|
||||
|
||||
class_name = context
|
||||
|
||||
fields: List[Tuple[str, Any, dataclasses.Field]] = []
|
||||
params: Dict[str, DocstringParam] = {}
|
||||
for prop_name, prop_node in node.properties.items():
|
||||
type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node)
|
||||
if prop_name in required:
|
||||
prop_type = type_def.type
|
||||
else:
|
||||
prop_type = Union[(None, type_def.type)]
|
||||
fields.append((prop_name, prop_type, dataclasses.field(default=type_def.default)))
|
||||
prop_desc = prop_node.title or prop_node.description
|
||||
if prop_desc is not None:
|
||||
params[prop_name] = DocstringParam(prop_name, prop_desc)
|
||||
|
||||
fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING)
|
||||
if sys.version_info >= (3, 12):
|
||||
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
|
||||
else:
|
||||
class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__})
|
||||
class_type.__doc__ = str(
|
||||
Docstring(
|
||||
short_description=node.title,
|
||||
long_description=node.description,
|
||||
params=params,
|
||||
)
|
||||
)
|
||||
setattr(module, class_name, class_type)
|
||||
return TypeDef(class_type, dataclasses.MISSING)
|
||||
|
||||
elif isinstance(node, JsonSchemaOneOf):
|
||||
union_defs = tuple(node_to_typedef(module, context, n) for n in node.oneOf)
|
||||
if any(d.default is not dataclasses.MISSING for d in union_defs):
|
||||
raise TypeError("disallowed: `default` for union member type")
|
||||
union_types = tuple(d.type for d in union_defs)
|
||||
return TypeDef(Union[union_types], dataclasses.MISSING)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchemaFlatteningOptions:
|
||||
qualified_names: bool = False
|
||||
recursive: bool = False
|
||||
|
||||
|
||||
def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema:
|
||||
top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema))
|
||||
flattener = SchemaFlattener(options)
|
||||
obj = flattener.flatten(top_node)
|
||||
return typing.cast(Schema, object_to_json(obj))
|
||||
|
||||
|
||||
class SchemaFlattener:
|
||||
options: SchemaFlatteningOptions
|
||||
|
||||
def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None:
|
||||
self.options = options or SchemaFlatteningOptions()
|
||||
|
||||
def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject:
|
||||
if source_node.type != "object":
|
||||
return source_node
|
||||
|
||||
source_props = source_node.properties or {}
|
||||
target_props: Dict[str, JsonSchemaAny] = {}
|
||||
|
||||
source_reqs = source_node.required or []
|
||||
target_reqs: List[str] = []
|
||||
|
||||
for name, prop in source_props.items():
|
||||
if not isinstance(prop, JsonSchemaObject):
|
||||
target_props[name] = prop
|
||||
if name in source_reqs:
|
||||
target_reqs.append(name)
|
||||
continue
|
||||
|
||||
if self.options.recursive:
|
||||
obj = self.flatten(prop)
|
||||
else:
|
||||
obj = prop
|
||||
if obj.properties is not None:
|
||||
if self.options.qualified_names:
|
||||
target_props.update((f"{name}.{n}", p) for n, p in obj.properties.items())
|
||||
else:
|
||||
target_props.update(obj.properties.items())
|
||||
if obj.required is not None:
|
||||
if self.options.qualified_names:
|
||||
target_reqs.extend(f"{name}.{n}" for n in obj.required)
|
||||
else:
|
||||
target_reqs.extend(obj.required)
|
||||
|
||||
target_node = copy.copy(source_node)
|
||||
target_node.properties = target_props or None
|
||||
target_node.additionalProperties = False
|
||||
target_node.required = target_reqs or None
|
||||
return target_node
|
46
llama_stack/strong_typing/core.py
Normal file
46
llama_stack/strong_typing/core.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Union
|
||||
|
||||
|
||||
class JsonObject:
|
||||
"Placeholder type for an unrestricted JSON object."
|
||||
|
||||
|
||||
class JsonArray:
|
||||
"Placeholder type for an unrestricted JSON array."
|
||||
|
||||
|
||||
# a JSON type with possible `null` values
|
||||
JsonType = Union[
|
||||
None,
|
||||
bool,
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
Dict[str, "JsonType"],
|
||||
List["JsonType"],
|
||||
]
|
||||
|
||||
# a JSON type that cannot contain `null` values
|
||||
StrictJsonType = Union[
|
||||
bool,
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
Dict[str, "StrictJsonType"],
|
||||
List["StrictJsonType"],
|
||||
]
|
||||
|
||||
# a meta-type that captures the object type in a JSON schema
|
||||
Schema = Dict[str, JsonType]
|
876
llama_stack/strong_typing/deserializer.py
Normal file
876
llama_stack/strong_typing/deserializer.py
Normal file
|
@ -0,0 +1,876 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import abc
|
||||
import base64
|
||||
import dataclasses
|
||||
import datetime
|
||||
import enum
|
||||
import inspect
|
||||
import ipaddress
|
||||
import sys
|
||||
import typing
|
||||
import uuid
|
||||
from types import ModuleType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .core import JsonType
|
||||
from .exception import JsonKeyError, JsonTypeError, JsonValueError
|
||||
from .inspection import (
|
||||
TypeLike,
|
||||
create_object,
|
||||
enum_value_types,
|
||||
evaluate_type,
|
||||
get_class_properties,
|
||||
get_class_property,
|
||||
get_resolved_hints,
|
||||
is_dataclass_instance,
|
||||
is_dataclass_type,
|
||||
is_named_tuple_type,
|
||||
is_type_annotated,
|
||||
is_type_literal,
|
||||
is_type_optional,
|
||||
unwrap_annotated_type,
|
||||
unwrap_literal_values,
|
||||
unwrap_optional_type,
|
||||
)
|
||||
from .mapping import python_field_to_json_property
|
||||
from .name import python_type_to_str
|
||||
|
||||
E = TypeVar("E", bound=enum.Enum)
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class Deserializer(abc.ABC, Generic[T]):
|
||||
"Parses a JSON value into a Python type."
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
"""
|
||||
Creates auxiliary parsers that this parser is depending on.
|
||||
|
||||
:param context: A module context for evaluating types specified as a string.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse(self, data: JsonType) -> T:
|
||||
"""
|
||||
Parses a JSON value into a Python type.
|
||||
|
||||
:param data: The JSON value to de-serialize.
|
||||
:returns: The Python object that the JSON value de-serializes to.
|
||||
"""
|
||||
|
||||
|
||||
class NoneDeserializer(Deserializer[None]):
|
||||
"Parses JSON `null` values into Python `None`."
|
||||
|
||||
def parse(self, data: JsonType) -> None:
|
||||
if data is not None:
|
||||
raise JsonTypeError(f"`None` type expects JSON `null` but instead received: {data}")
|
||||
return None
|
||||
|
||||
|
||||
class BoolDeserializer(Deserializer[bool]):
|
||||
"Parses JSON `boolean` values into Python `bool` type."
|
||||
|
||||
def parse(self, data: JsonType) -> bool:
|
||||
if not isinstance(data, bool):
|
||||
raise JsonTypeError(f"`bool` type expects JSON `boolean` data but instead received: {data}")
|
||||
return bool(data)
|
||||
|
||||
|
||||
class IntDeserializer(Deserializer[int]):
|
||||
"Parses JSON `number` values into Python `int` type."
|
||||
|
||||
def parse(self, data: JsonType) -> int:
|
||||
if not isinstance(data, int):
|
||||
raise JsonTypeError(f"`int` type expects integer data as JSON `number` but instead received: {data}")
|
||||
return int(data)
|
||||
|
||||
|
||||
class FloatDeserializer(Deserializer[float]):
|
||||
"Parses JSON `number` values into Python `float` type."
|
||||
|
||||
def parse(self, data: JsonType) -> float:
|
||||
if not isinstance(data, float) and not isinstance(data, int):
|
||||
raise JsonTypeError(f"`int` type expects data as JSON `number` but instead received: {data}")
|
||||
return float(data)
|
||||
|
||||
|
||||
class StringDeserializer(Deserializer[str]):
|
||||
"Parses JSON `string` values into Python `str` type."
|
||||
|
||||
def parse(self, data: JsonType) -> str:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`str` type expects JSON `string` data but instead received: {data}")
|
||||
return str(data)
|
||||
|
||||
|
||||
class BytesDeserializer(Deserializer[bytes]):
|
||||
"Parses JSON `string` values of Base64-encoded strings into Python `bytes` type."
|
||||
|
||||
def parse(self, data: JsonType) -> bytes:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`bytes` type expects JSON `string` data but instead received: {data}")
|
||||
return base64.b64decode(data, validate=True)
|
||||
|
||||
|
||||
class DateTimeDeserializer(Deserializer[datetime.datetime]):
|
||||
"Parses JSON `string` values representing timestamps in ISO 8601 format to Python `datetime` with time zone."
|
||||
|
||||
def parse(self, data: JsonType) -> datetime.datetime:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`datetime` type expects JSON `string` data but instead received: {data}")
|
||||
|
||||
if data.endswith("Z"):
|
||||
data = f"{data[:-1]}+00:00" # Python's isoformat() does not support military time zones like "Zulu" for UTC
|
||||
timestamp = datetime.datetime.fromisoformat(data)
|
||||
if timestamp.tzinfo is None:
|
||||
raise JsonValueError(f"timestamp lacks explicit time zone designator: {data}")
|
||||
return timestamp
|
||||
|
||||
|
||||
class DateDeserializer(Deserializer[datetime.date]):
|
||||
"Parses JSON `string` values representing dates in ISO 8601 format to Python `date` type."
|
||||
|
||||
def parse(self, data: JsonType) -> datetime.date:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`date` type expects JSON `string` data but instead received: {data}")
|
||||
|
||||
return datetime.date.fromisoformat(data)
|
||||
|
||||
|
||||
class TimeDeserializer(Deserializer[datetime.time]):
|
||||
"Parses JSON `string` values representing time instances in ISO 8601 format to Python `time` type with time zone."
|
||||
|
||||
def parse(self, data: JsonType) -> datetime.time:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`time` type expects JSON `string` data but instead received: {data}")
|
||||
|
||||
return datetime.time.fromisoformat(data)
|
||||
|
||||
|
||||
class UUIDDeserializer(Deserializer[uuid.UUID]):
|
||||
"Parses JSON `string` values of UUID strings into Python `uuid.UUID` type."
|
||||
|
||||
def parse(self, data: JsonType) -> uuid.UUID:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`UUID` type expects JSON `string` data but instead received: {data}")
|
||||
return uuid.UUID(data)
|
||||
|
||||
|
||||
class IPv4Deserializer(Deserializer[ipaddress.IPv4Address]):
|
||||
"Parses JSON `string` values of IPv4 address strings into Python `ipaddress.IPv4Address` type."
|
||||
|
||||
def parse(self, data: JsonType) -> ipaddress.IPv4Address:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`IPv4Address` type expects JSON `string` data but instead received: {data}")
|
||||
return ipaddress.IPv4Address(data)
|
||||
|
||||
|
||||
class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]):
|
||||
"Parses JSON `string` values of IPv6 address strings into Python `ipaddress.IPv6Address` type."
|
||||
|
||||
def parse(self, data: JsonType) -> ipaddress.IPv6Address:
|
||||
if not isinstance(data, str):
|
||||
raise JsonTypeError(f"`IPv6Address` type expects JSON `string` data but instead received: {data}")
|
||||
return ipaddress.IPv6Address(data)
|
||||
|
||||
|
||||
class ListDeserializer(Deserializer[List[T]]):
|
||||
"Recursively de-serializes a JSON array into a Python `list`."
|
||||
|
||||
item_type: Type[T]
|
||||
item_parser: Deserializer
|
||||
|
||||
def __init__(self, item_type: Type[T]) -> None:
|
||||
self.item_type = item_type
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
self.item_parser = _get_deserializer(self.item_type, context)
|
||||
|
||||
def parse(self, data: JsonType) -> List[T]:
|
||||
if not isinstance(data, list):
|
||||
type_name = python_type_to_str(self.item_type)
|
||||
raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}")
|
||||
|
||||
return [self.item_parser.parse(item) for item in data]
|
||||
|
||||
|
||||
class DictDeserializer(Deserializer[Dict[K, V]]):
|
||||
"Recursively de-serializes a JSON object into a Python `dict`."
|
||||
|
||||
key_type: Type[K]
|
||||
value_type: Type[V]
|
||||
value_parser: Deserializer[V]
|
||||
|
||||
def __init__(self, key_type: Type[K], value_type: Type[V]) -> None:
|
||||
self.key_type = key_type
|
||||
self.value_type = value_type
|
||||
self._check_key_type()
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
self.value_parser = _get_deserializer(self.value_type, context)
|
||||
|
||||
def _check_key_type(self) -> None:
|
||||
if self.key_type is str:
|
||||
return
|
||||
|
||||
if issubclass(self.key_type, enum.Enum):
|
||||
value_types = enum_value_types(self.key_type)
|
||||
if len(value_types) != 1:
|
||||
raise JsonTypeError(
|
||||
f"type `{self.container_type}` has invalid key type, "
|
||||
f"enumerations must have a consistent member value type but several types found: {value_types}"
|
||||
)
|
||||
value_type = value_types.pop()
|
||||
if value_type is not str:
|
||||
f"`type `{self.container_type}` has invalid enumeration key type, expected `enum.Enum` with string values"
|
||||
return
|
||||
|
||||
raise JsonTypeError(
|
||||
f"`type `{self.container_type}` has invalid key type, expected `str` or `enum.Enum` with string values"
|
||||
)
|
||||
|
||||
@property
|
||||
def container_type(self) -> str:
|
||||
key_type_name = python_type_to_str(self.key_type)
|
||||
value_type_name = python_type_to_str(self.value_type)
|
||||
return f"Dict[{key_type_name}, {value_type_name}]"
|
||||
|
||||
def parse(self, data: JsonType) -> Dict[K, V]:
|
||||
if not isinstance(data, dict):
|
||||
raise JsonTypeError(
|
||||
f"`type `{self.container_type}` expects JSON `object` data but instead received: {data}"
|
||||
)
|
||||
|
||||
return dict(
|
||||
(self.key_type(key), self.value_parser.parse(value)) # type: ignore[call-arg]
|
||||
for key, value in data.items()
|
||||
)
|
||||
|
||||
|
||||
class SetDeserializer(Deserializer[Set[T]]):
|
||||
"Recursively de-serializes a JSON list into a Python `set`."
|
||||
|
||||
member_type: Type[T]
|
||||
member_parser: Deserializer
|
||||
|
||||
def __init__(self, member_type: Type[T]) -> None:
|
||||
self.member_type = member_type
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
self.member_parser = _get_deserializer(self.member_type, context)
|
||||
|
||||
def parse(self, data: JsonType) -> Set[T]:
|
||||
if not isinstance(data, list):
|
||||
type_name = python_type_to_str(self.member_type)
|
||||
raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}")
|
||||
|
||||
return set(self.member_parser.parse(item) for item in data)
|
||||
|
||||
|
||||
class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
|
||||
"Recursively de-serializes a JSON list into a Python `tuple`."
|
||||
|
||||
item_types: Tuple[Type[Any], ...]
|
||||
item_parsers: Tuple[Deserializer[Any], ...]
|
||||
|
||||
def __init__(self, item_types: Tuple[Type[Any], ...]) -> None:
|
||||
self.item_types = item_types
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types)
|
||||
|
||||
@property
|
||||
def container_type(self) -> str:
|
||||
type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types)
|
||||
return f"Tuple[{type_names}]"
|
||||
|
||||
def parse(self, data: JsonType) -> Tuple[Any, ...]:
|
||||
if not isinstance(data, list) or len(data) != len(self.item_parsers):
|
||||
if not isinstance(data, list):
|
||||
raise JsonTypeError(
|
||||
f"type `{self.container_type}` expects JSON `array` data but instead received: {data}"
|
||||
)
|
||||
else:
|
||||
count = len(self.item_parsers)
|
||||
raise JsonValueError(
|
||||
f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}"
|
||||
)
|
||||
|
||||
return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data))
|
||||
|
||||
|
||||
class UnionDeserializer(Deserializer):
|
||||
"De-serializes a JSON value (of any type) into a Python union type."
|
||||
|
||||
member_types: Tuple[type, ...]
|
||||
member_parsers: Tuple[Deserializer, ...]
|
||||
|
||||
def __init__(self, member_types: Tuple[type, ...]) -> None:
|
||||
self.member_types = member_types
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types)
|
||||
|
||||
def parse(self, data: JsonType) -> Any:
|
||||
for member_parser in self.member_parsers:
|
||||
# iterate over potential types of discriminated union
|
||||
try:
|
||||
return member_parser.parse(data)
|
||||
except (JsonKeyError, JsonTypeError):
|
||||
# indicates a required field is missing from JSON dict -OR- the data cannot be cast to the expected type,
|
||||
# i.e. we don't have the type that we are looking for
|
||||
continue
|
||||
|
||||
type_names = ", ".join(python_type_to_str(member_type) for member_type in self.member_types)
|
||||
raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}")
|
||||
|
||||
|
||||
def get_literal_properties(typ: type) -> Set[str]:
|
||||
"Returns the names of all properties in a class that are of a literal type."
|
||||
|
||||
return set(
|
||||
property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type)
|
||||
)
|
||||
|
||||
|
||||
def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
|
||||
"Returns a set of properties with literal type that are common across all specified classes."
|
||||
|
||||
if not types or not all(isinstance(typ, type) for typ in types):
|
||||
return set()
|
||||
|
||||
props = get_literal_properties(types[0])
|
||||
for typ in types[1:]:
|
||||
props = props & get_literal_properties(typ)
|
||||
|
||||
return props
|
||||
|
||||
|
||||
class TaggedUnionDeserializer(Deserializer):
|
||||
"De-serializes a JSON value with one or more disambiguating properties into a Python union type."
|
||||
|
||||
member_types: Tuple[type, ...]
|
||||
disambiguating_properties: Set[str]
|
||||
member_parsers: Dict[Tuple[str, Any], Deserializer]
|
||||
|
||||
def __init__(self, member_types: Tuple[type, ...]) -> None:
|
||||
self.member_types = member_types
|
||||
self.disambiguating_properties = get_discriminating_properties(member_types)
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
self.member_parsers = {}
|
||||
for member_type in self.member_types:
|
||||
for property_name in self.disambiguating_properties:
|
||||
literal_type = get_class_property(member_type, property_name)
|
||||
if not literal_type:
|
||||
continue
|
||||
|
||||
for literal_value in unwrap_literal_values(literal_type):
|
||||
tpl = (property_name, literal_value)
|
||||
if tpl in self.member_parsers:
|
||||
raise JsonTypeError(
|
||||
f"disambiguating property `{property_name}` in type `{self.union_type}` has a duplicate value: {literal_value}"
|
||||
)
|
||||
|
||||
self.member_parsers[tpl] = _get_deserializer(member_type, context)
|
||||
|
||||
@property
|
||||
def union_type(self) -> str:
|
||||
type_names = ", ".join(python_type_to_str(member_type) for member_type in self.member_types)
|
||||
return f"Union[{type_names}]"
|
||||
|
||||
def parse(self, data: JsonType) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
raise JsonTypeError(
|
||||
f"tagged union type `{self.union_type}` expects JSON `object` data but instead received: {data}"
|
||||
)
|
||||
|
||||
for property_name in self.disambiguating_properties:
|
||||
disambiguating_value = data.get(property_name)
|
||||
if disambiguating_value is None:
|
||||
continue
|
||||
|
||||
member_parser = self.member_parsers.get((property_name, disambiguating_value))
|
||||
if member_parser is None:
|
||||
raise JsonTypeError(
|
||||
f"disambiguating property value is invalid for tagged union type `{self.union_type}`: {data}"
|
||||
)
|
||||
|
||||
return member_parser.parse(data)
|
||||
|
||||
raise JsonTypeError(
|
||||
f"disambiguating property value is missing for tagged union type `{self.union_type}`: {data}"
|
||||
)
|
||||
|
||||
|
||||
class LiteralDeserializer(Deserializer):
|
||||
"De-serializes a JSON value into a Python literal type."
|
||||
|
||||
values: Tuple[Any, ...]
|
||||
parser: Deserializer
|
||||
|
||||
def __init__(self, values: Tuple[Any, ...]) -> None:
|
||||
self.values = values
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
literal_type_tuple = tuple(type(value) for value in self.values)
|
||||
literal_type_set = set(literal_type_tuple)
|
||||
if len(literal_type_set) != 1:
|
||||
value_names = ", ".join(repr(value) for value in self.values)
|
||||
raise TypeError(
|
||||
f"type `Literal[{value_names}]` expects consistent literal value types but got: {literal_type_tuple}"
|
||||
)
|
||||
|
||||
literal_type = literal_type_set.pop()
|
||||
self.parser = _get_deserializer(literal_type, context)
|
||||
|
||||
def parse(self, data: JsonType) -> Any:
|
||||
value = self.parser.parse(data)
|
||||
if value not in self.values:
|
||||
value_names = ", ".join(repr(value) for value in self.values)
|
||||
raise JsonTypeError(f"type `Literal[{value_names}]` could not be instantiated from: {data}")
|
||||
return value
|
||||
|
||||
|
||||
class EnumDeserializer(Deserializer[E]):
|
||||
"Returns an enumeration instance based on the enumeration value read from a JSON value."
|
||||
|
||||
enum_type: Type[E]
|
||||
|
||||
def __init__(self, enum_type: Type[E]) -> None:
|
||||
self.enum_type = enum_type
|
||||
|
||||
def parse(self, data: JsonType) -> E:
|
||||
return self.enum_type(data)
|
||||
|
||||
|
||||
class CustomDeserializer(Deserializer[T]):
|
||||
"Uses the `from_json` class method in class to de-serialize the object from JSON."
|
||||
|
||||
converter: Callable[[JsonType], T]
|
||||
|
||||
def __init__(self, converter: Callable[[JsonType], T]) -> None:
|
||||
self.converter = converter
|
||||
|
||||
def parse(self, data: JsonType) -> T:
|
||||
return self.converter(data)
|
||||
|
||||
|
||||
class FieldDeserializer(abc.ABC, Generic[T, R]):
|
||||
"""
|
||||
Deserializes a JSON property into a Python object field.
|
||||
|
||||
:param property_name: The name of the JSON property to read from a JSON `object`.
|
||||
:param field_name: The name of the field in a Python class to write data to.
|
||||
:param parser: A compatible deserializer that can handle the field's type.
|
||||
"""
|
||||
|
||||
property_name: str
|
||||
field_name: str
|
||||
parser: Deserializer[T]
|
||||
|
||||
def __init__(self, property_name: str, field_name: str, parser: Deserializer[T]) -> None:
|
||||
self.property_name = property_name
|
||||
self.field_name = field_name
|
||||
self.parser = parser
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> R: ...
|
||||
|
||||
|
||||
class RequiredFieldDeserializer(FieldDeserializer[T, T]):
|
||||
"Deserializes a JSON property into a mandatory Python object field."
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> T:
|
||||
if self.property_name not in data:
|
||||
raise JsonKeyError(f"missing required property `{self.property_name}` from JSON object: {data}")
|
||||
|
||||
return self.parser.parse(data[self.property_name])
|
||||
|
||||
|
||||
class OptionalFieldDeserializer(FieldDeserializer[T, Optional[T]]):
|
||||
"Deserializes a JSON property into an optional Python object field with a default value of `None`."
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> Optional[T]:
|
||||
value = data.get(self.property_name)
|
||||
if value is not None:
|
||||
return self.parser.parse(value)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class DefaultFieldDeserializer(FieldDeserializer[T, T]):
|
||||
"Deserializes a JSON property into a Python object field with an explicit default value."
|
||||
|
||||
default_value: T
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
property_name: str,
|
||||
field_name: str,
|
||||
parser: Deserializer,
|
||||
default_value: T,
|
||||
) -> None:
|
||||
super().__init__(property_name, field_name, parser)
|
||||
self.default_value = default_value
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> T:
|
||||
value = data.get(self.property_name)
|
||||
if value is not None:
|
||||
return self.parser.parse(value)
|
||||
else:
|
||||
return self.default_value
|
||||
|
||||
|
||||
class DefaultFactoryFieldDeserializer(FieldDeserializer[T, T]):
|
||||
"Deserializes a JSON property into an optional Python object field with an explicit default value factory."
|
||||
|
||||
default_factory: Callable[[], T]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
property_name: str,
|
||||
field_name: str,
|
||||
parser: Deserializer[T],
|
||||
default_factory: Callable[[], T],
|
||||
) -> None:
|
||||
super().__init__(property_name, field_name, parser)
|
||||
self.default_factory = default_factory
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> T:
|
||||
value = data.get(self.property_name)
|
||||
if value is not None:
|
||||
return self.parser.parse(value)
|
||||
else:
|
||||
return self.default_factory()
|
||||
|
||||
|
||||
class ClassDeserializer(Deserializer[T]):
|
||||
"Base class for de-serializing class-like types such as data classes, named tuples and regular classes."
|
||||
|
||||
class_type: type
|
||||
property_parsers: List[FieldDeserializer]
|
||||
property_fields: Set[str]
|
||||
|
||||
def __init__(self, class_type: Type[T]) -> None:
|
||||
self.class_type = class_type
|
||||
|
||||
def assign(self, property_parsers: List[FieldDeserializer]) -> None:
|
||||
self.property_parsers = property_parsers
|
||||
self.property_fields = set(property_parser.property_name for property_parser in property_parsers)
|
||||
|
||||
def parse(self, data: JsonType) -> T:
|
||||
if not isinstance(data, dict):
|
||||
type_name = python_type_to_str(self.class_type)
|
||||
raise JsonTypeError(f"`type `{type_name}` expects JSON `object` data but instead received: {data}")
|
||||
|
||||
object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data)
|
||||
|
||||
field_values = {}
|
||||
for property_parser in self.property_parsers:
|
||||
field_values[property_parser.field_name] = property_parser.parse_field(object_data)
|
||||
|
||||
if not self.property_fields.issuperset(object_data):
|
||||
unassigned_names = [name for name in object_data if name not in self.property_fields]
|
||||
raise JsonKeyError(f"unrecognized fields in JSON object: {unassigned_names}")
|
||||
|
||||
return self.create(**field_values)
|
||||
|
||||
def create(self, **field_values: Any) -> T:
|
||||
"Instantiates an object with a collection of property values."
|
||||
|
||||
obj: T = create_object(self.class_type)
|
||||
|
||||
# use `setattr` on newly created object instance
|
||||
for field_name, field_value in field_values.items():
|
||||
setattr(obj, field_name, field_value)
|
||||
return obj
|
||||
|
||||
|
||||
class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
|
||||
"De-serializes a named tuple from a JSON `object`."
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
property_parsers: List[FieldDeserializer] = [
|
||||
RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context))
|
||||
for field_name, field_type in get_resolved_hints(self.class_type).items()
|
||||
]
|
||||
super().assign(property_parsers)
|
||||
|
||||
def create(self, **field_values: Any) -> NamedTuple:
|
||||
return self.class_type(**field_values)
|
||||
|
||||
|
||||
class DataclassDeserializer(ClassDeserializer[T]):
|
||||
"De-serializes a data class from a JSON `object`."
|
||||
|
||||
def __init__(self, class_type: Type[T]) -> None:
|
||||
if not dataclasses.is_dataclass(class_type):
|
||||
raise TypeError("expected: data-class type")
|
||||
super().__init__(class_type) # type: ignore[arg-type]
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
property_parsers: List[FieldDeserializer] = []
|
||||
resolved_hints = get_resolved_hints(self.class_type)
|
||||
for field in dataclasses.fields(self.class_type):
|
||||
field_type = resolved_hints[field.name]
|
||||
property_name = python_field_to_json_property(field.name, field_type)
|
||||
|
||||
is_optional = is_type_optional(field_type)
|
||||
has_default = field.default is not dataclasses.MISSING
|
||||
has_default_factory = field.default_factory is not dataclasses.MISSING
|
||||
|
||||
if is_optional:
|
||||
required_type: Type[T] = unwrap_optional_type(field_type)
|
||||
else:
|
||||
required_type = field_type
|
||||
|
||||
parser = _get_deserializer(required_type, context)
|
||||
|
||||
if has_default:
|
||||
field_parser: FieldDeserializer = DefaultFieldDeserializer(
|
||||
property_name, field.name, parser, field.default
|
||||
)
|
||||
elif has_default_factory:
|
||||
default_factory = typing.cast(Callable[[], Any], field.default_factory)
|
||||
field_parser = DefaultFactoryFieldDeserializer(property_name, field.name, parser, default_factory)
|
||||
elif is_optional:
|
||||
field_parser = OptionalFieldDeserializer(property_name, field.name, parser)
|
||||
else:
|
||||
field_parser = RequiredFieldDeserializer(property_name, field.name, parser)
|
||||
|
||||
property_parsers.append(field_parser)
|
||||
|
||||
super().assign(property_parsers)
|
||||
|
||||
|
||||
class FrozenDataclassDeserializer(DataclassDeserializer[T]):
|
||||
"De-serializes a frozen data class from a JSON `object`."
|
||||
|
||||
def create(self, **field_values: Any) -> T:
|
||||
"Instantiates an object with a collection of property values."
|
||||
|
||||
# create object instance without calling `__init__`
|
||||
obj: T = create_object(self.class_type)
|
||||
|
||||
# can't use `setattr` on frozen dataclasses, pass member variable values to `__init__`
|
||||
obj.__init__(**field_values) # type: ignore
|
||||
return obj
|
||||
|
||||
|
||||
class TypedClassDeserializer(ClassDeserializer[T]):
|
||||
"De-serializes a class with type annotations from a JSON `object` by iterating over class properties."
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
property_parsers: List[FieldDeserializer] = []
|
||||
for field_name, field_type in get_resolved_hints(self.class_type).items():
|
||||
property_name = python_field_to_json_property(field_name, field_type)
|
||||
|
||||
is_optional = is_type_optional(field_type)
|
||||
|
||||
if is_optional:
|
||||
required_type: Type[T] = unwrap_optional_type(field_type)
|
||||
else:
|
||||
required_type = field_type
|
||||
|
||||
parser = _get_deserializer(required_type, context)
|
||||
|
||||
if is_optional:
|
||||
field_parser: FieldDeserializer = OptionalFieldDeserializer(property_name, field_name, parser)
|
||||
else:
|
||||
field_parser = RequiredFieldDeserializer(property_name, field_name, parser)
|
||||
|
||||
property_parsers.append(field_parser)
|
||||
|
||||
super().assign(property_parsers)
|
||||
|
||||
|
||||
def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Deserializer:
|
||||
"""
|
||||
Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string.
|
||||
|
||||
When de-serializing a JSON object into a Python object, the following transformations are applied:
|
||||
|
||||
* Fundamental types are parsed as `bool`, `int`, `float` or `str`.
|
||||
* Date and time types are parsed from the ISO 8601 format with time zone into the corresponding Python type
|
||||
`datetime`, `date` or `time`.
|
||||
* Byte arrays are read from a string with Base64 encoding into a `bytes` instance.
|
||||
* UUIDs are extracted from a UUID string compliant with RFC 4122 into a `uuid.UUID` instance.
|
||||
* Enumerations are instantiated with a lookup on enumeration value.
|
||||
* Containers (e.g. `list`, `dict`, `set`, `tuple`) are parsed recursively.
|
||||
* Complex objects with properties (including data class types) are populated from dictionaries of key-value pairs
|
||||
using reflection (enumerating type annotations).
|
||||
|
||||
:raises TypeError: A de-serializer engine cannot be constructed for the input type.
|
||||
"""
|
||||
|
||||
if context is None:
|
||||
if isinstance(typ, type):
|
||||
context = sys.modules[typ.__module__]
|
||||
|
||||
return _get_deserializer(typ, context)
|
||||
|
||||
|
||||
_CACHE: Dict[Tuple[str, str], Deserializer] = {}
|
||||
|
||||
|
||||
def _get_deserializer(typ: TypeLike, context: Optional[ModuleType]) -> Deserializer:
|
||||
"Creates or re-uses a de-serializer engine to parse an object obtained from a JSON string."
|
||||
|
||||
cache_key = None
|
||||
|
||||
if isinstance(typ, (str, typing.ForwardRef)):
|
||||
if context is None:
|
||||
raise TypeError(f"missing context for evaluating type: {typ}")
|
||||
|
||||
if isinstance(typ, str):
|
||||
if hasattr(context, typ):
|
||||
cache_key = (context.__name__, typ)
|
||||
elif isinstance(typ, typing.ForwardRef):
|
||||
if hasattr(context, typ.__forward_arg__):
|
||||
cache_key = (context.__name__, typ.__forward_arg__)
|
||||
|
||||
typ = evaluate_type(typ, context)
|
||||
|
||||
typ = unwrap_annotated_type(typ) if is_type_annotated(typ) else typ
|
||||
|
||||
if isinstance(typ, type) and typing.get_origin(typ) is None:
|
||||
cache_key = (typ.__module__, typ.__name__)
|
||||
|
||||
if cache_key is not None:
|
||||
deserializer = _CACHE.get(cache_key)
|
||||
if deserializer is None:
|
||||
deserializer = _create_deserializer(typ)
|
||||
|
||||
# store de-serializer immediately in cache to avoid stack overflow for recursive types
|
||||
_CACHE[cache_key] = deserializer
|
||||
|
||||
if isinstance(typ, type):
|
||||
# use type's own module as context for evaluating member types
|
||||
context = sys.modules[typ.__module__]
|
||||
|
||||
# create any de-serializers this de-serializer is depending on
|
||||
deserializer.build(context)
|
||||
else:
|
||||
# special forms are not always hashable, create a new de-serializer every time
|
||||
deserializer = _create_deserializer(typ)
|
||||
deserializer.build(context)
|
||||
|
||||
return deserializer
|
||||
|
||||
|
||||
def _create_deserializer(typ: TypeLike) -> Deserializer:
|
||||
"Creates a de-serializer engine to parse an object obtained from a JSON string."
|
||||
|
||||
# check for well-known types
|
||||
if typ is type(None):
|
||||
return NoneDeserializer()
|
||||
elif typ is bool:
|
||||
return BoolDeserializer()
|
||||
elif typ is int:
|
||||
return IntDeserializer()
|
||||
elif typ is float:
|
||||
return FloatDeserializer()
|
||||
elif typ is str:
|
||||
return StringDeserializer()
|
||||
elif typ is bytes:
|
||||
return BytesDeserializer()
|
||||
elif typ is datetime.datetime:
|
||||
return DateTimeDeserializer()
|
||||
elif typ is datetime.date:
|
||||
return DateDeserializer()
|
||||
elif typ is datetime.time:
|
||||
return TimeDeserializer()
|
||||
elif typ is uuid.UUID:
|
||||
return UUIDDeserializer()
|
||||
elif typ is ipaddress.IPv4Address:
|
||||
return IPv4Deserializer()
|
||||
elif typ is ipaddress.IPv6Address:
|
||||
return IPv6Deserializer()
|
||||
|
||||
# dynamically-typed collection types
|
||||
if typ is list:
|
||||
raise TypeError("explicit item type required: use `List[T]` instead of `list`")
|
||||
if typ is dict:
|
||||
raise TypeError("explicit key and value types required: use `Dict[K, V]` instead of `dict`")
|
||||
if typ is set:
|
||||
raise TypeError("explicit member type required: use `Set[T]` instead of `set`")
|
||||
if typ is tuple:
|
||||
raise TypeError("explicit item type list required: use `Tuple[T, ...]` instead of `tuple`")
|
||||
|
||||
# generic types (e.g. list, dict, set, etc.)
|
||||
origin_type = typing.get_origin(typ)
|
||||
if origin_type is list:
|
||||
(list_item_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return ListDeserializer(list_item_type)
|
||||
elif origin_type is dict:
|
||||
key_type, value_type = typing.get_args(typ)
|
||||
return DictDeserializer(key_type, value_type)
|
||||
elif origin_type is set:
|
||||
(set_member_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return SetDeserializer(set_member_type)
|
||||
elif origin_type is tuple:
|
||||
return TupleDeserializer(typing.get_args(typ))
|
||||
elif origin_type is Union:
|
||||
union_args = typing.get_args(typ)
|
||||
if get_discriminating_properties(union_args):
|
||||
return TaggedUnionDeserializer(union_args)
|
||||
else:
|
||||
return UnionDeserializer(union_args)
|
||||
elif origin_type is Literal:
|
||||
return LiteralDeserializer(typing.get_args(typ))
|
||||
|
||||
if not inspect.isclass(typ):
|
||||
if is_dataclass_instance(typ):
|
||||
raise TypeError(f"dataclass type expected but got instance: {typ}")
|
||||
else:
|
||||
raise TypeError(f"unable to de-serialize unrecognized type: {typ}")
|
||||
|
||||
if issubclass(typ, enum.Enum):
|
||||
return EnumDeserializer(typ)
|
||||
|
||||
if is_named_tuple_type(typ):
|
||||
return NamedTupleDeserializer(typ)
|
||||
|
||||
# check if object has a custom de-serialization method
|
||||
convert_func = getattr(typ, "from_json", None)
|
||||
if callable(convert_func):
|
||||
return CustomDeserializer(convert_func)
|
||||
|
||||
if is_dataclass_type(typ):
|
||||
dataclass_params = getattr(typ, "__dataclass_params__", None)
|
||||
if dataclass_params is not None and dataclass_params.frozen:
|
||||
return FrozenDataclassDeserializer(typ)
|
||||
else:
|
||||
return DataclassDeserializer(typ)
|
||||
|
||||
return TypedClassDeserializer(typ)
|
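For orientation, here is a minimal usage sketch (not part of the diff) for the de-serializer above; it assumes the public entry point is `create_deserializer`, the same name imported by `serialization.py` later in this change, and that the engine's `parse` method accepts the JSON representation.

```python
from dataclasses import dataclass
from typing import List, Optional

from llama_stack.strong_typing.deserializer import create_deserializer


@dataclass
class Track:
    title: str
    length_seconds: int
    tags: Optional[List[str]] = None


# the engine is created (and cached) per type, then re-used for every parse
engine = create_deserializer(Track)
track = engine.parse({"title": "Intro", "length_seconds": 92, "tags": ["demo"]})
assert track == Track(title="Intro", length_seconds=92, tags=["demo"])
```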
399
llama_stack/strong_typing/docstring.py
Normal file
|
@@ -0,0 +1,399 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import dataclasses
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from io import StringIO
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeGuard
|
||||
else:
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from .inspection import (
|
||||
DataclassInstance,
|
||||
get_class_properties,
|
||||
get_signature,
|
||||
is_dataclass_type,
|
||||
is_type_enum,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocstringParam:
|
||||
"""
|
||||
A parameter declaration in a parameter block.
|
||||
|
||||
:param name: The name of the parameter.
|
||||
:param description: The description text for the parameter.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
param_type: type = inspect.Signature.empty
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f":param {self.name}: {self.description}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocstringReturns:
|
||||
"""
|
||||
A `returns` declaration extracted from a docstring.
|
||||
|
||||
:param description: The description text for the return value.
|
||||
"""
|
||||
|
||||
description: str
|
||||
return_type: type = inspect.Signature.empty
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f":returns: {self.description}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocstringRaises:
|
||||
"""
|
||||
A `raises` declaration extracted from a docstring.
|
||||
|
||||
:param typename: The type name of the exception raised.
|
||||
:param description: The description associated with the exception raised.
|
||||
"""
|
||||
|
||||
typename: str
|
||||
description: str
|
||||
raise_type: type = inspect.Signature.empty
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f":raises {self.typename}: {self.description}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Docstring:
|
||||
"""
|
||||
Represents the documentation string (a.k.a. docstring) for a type such as a (data) class or function.
|
||||
|
||||
A docstring is broken down into the following components:
|
||||
* A short description, which is the first block of text in the documentation string, and ends with a double
|
||||
newline or a parameter block.
|
||||
* A long description, which is the optional block of text following the short description, and ends with
|
||||
a parameter block.
|
||||
* A parameter block of named parameter and description string pairs in ReST-style.
|
||||
* A `returns` declaration, which adds explanation to the return value.
|
||||
* A `raises` declaration, which adds explanation to the exception type raised by the function on error.
|
||||
|
||||
When the docstring is attached to a data class, it is understood as the documentation string of the class
|
||||
`__init__` method.
|
||||
|
||||
:param short_description: The short description text parsed from a docstring.
|
||||
:param long_description: The long description text parsed from a docstring.
|
||||
:param params: The parameter block extracted from a docstring.
|
||||
:param returns: The returns declaration extracted from a docstring.
|
||||
"""
|
||||
|
||||
short_description: Optional[str] = None
|
||||
long_description: Optional[str] = None
|
||||
params: Dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
|
||||
returns: Optional[DocstringReturns] = None
|
||||
raises: Dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def full_description(self) -> Optional[str]:
|
||||
if self.short_description and self.long_description:
|
||||
return f"{self.short_description}\n\n{self.long_description}"
|
||||
elif self.short_description:
|
||||
return self.short_description
|
||||
else:
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
output = StringIO()
|
||||
|
||||
has_description = self.short_description or self.long_description
|
||||
has_blocks = self.params or self.returns or self.raises
|
||||
|
||||
if has_description:
|
||||
if self.short_description and self.long_description:
|
||||
output.write(self.short_description)
|
||||
output.write("\n\n")
|
||||
output.write(self.long_description)
|
||||
elif self.short_description:
|
||||
output.write(self.short_description)
|
||||
|
||||
if has_blocks:
|
||||
if has_description:
|
||||
output.write("\n")
|
||||
|
||||
for param in self.params.values():
|
||||
output.write("\n")
|
||||
output.write(str(param))
|
||||
if self.returns:
|
||||
output.write("\n")
|
||||
output.write(str(self.returns))
|
||||
for raises in self.raises.values():
|
||||
output.write("\n")
|
||||
output.write(str(raises))
|
||||
|
||||
s = output.getvalue()
|
||||
output.close()
|
||||
return s
|
||||
|
||||
|
||||
def is_exception(member: object) -> TypeGuard[Type[BaseException]]:
|
||||
return isinstance(member, type) and issubclass(member, BaseException)
|
||||
|
||||
|
||||
def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]:
|
||||
"Returns all exception classes declared in a module."
|
||||
|
||||
return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)}
|
||||
|
||||
|
||||
class SupportsDoc(Protocol):
|
||||
__doc__: Optional[str]
|
||||
|
||||
|
||||
def parse_type(typ: SupportsDoc) -> Docstring:
|
||||
"""
|
||||
Parse the docstring of a type into its components.
|
||||
|
||||
:param typ: The type whose documentation string to parse.
|
||||
:returns: Components of the documentation string.
|
||||
"""
|
||||
|
||||
doc = get_docstring(typ)
|
||||
if doc is None:
|
||||
return Docstring()
|
||||
|
||||
docstring = parse_text(doc)
|
||||
check_docstring(typ, docstring)
|
||||
|
||||
# assign parameter and return types
|
||||
if is_dataclass_type(typ):
|
||||
properties = dict(get_class_properties(typing.cast(type, typ)))
|
||||
|
||||
for name, param in docstring.params.items():
|
||||
param.param_type = properties[name]
|
||||
|
||||
elif inspect.isfunction(typ):
|
||||
signature = get_signature(typ)
|
||||
for name, param in docstring.params.items():
|
||||
param.param_type = signature.parameters[name].annotation
|
||||
if docstring.returns:
|
||||
docstring.returns.return_type = signature.return_annotation
|
||||
|
||||
# assign exception types
|
||||
defining_module = inspect.getmodule(typ)
|
||||
if defining_module:
|
||||
context: Dict[str, type] = {}
|
||||
context.update(get_exceptions(builtins))
|
||||
context.update(get_exceptions(defining_module))
|
||||
for exc_name, exc in docstring.raises.items():
|
||||
raise_type = context.get(exc_name)
|
||||
if raise_type is None:
|
||||
type_name = getattr(typ, "__qualname__", None) or getattr(typ, "__name__", None) or None
|
||||
raise TypeError(
|
||||
f"doc-string exception type `{exc_name}` is not an exception defined in the context of `{type_name}`"
|
||||
)
|
||||
|
||||
exc.raise_type = raise_type
|
||||
|
||||
return docstring
|
||||
|
||||
|
||||
def parse_text(text: str) -> Docstring:
|
||||
"""
|
||||
Parse a ReST-style docstring into its components.
|
||||
|
||||
:param text: The documentation string to parse, typically acquired as `type.__doc__`.
|
||||
:returns: Components of the documentation string.
|
||||
"""
|
||||
|
||||
if not text:
|
||||
return Docstring()
|
||||
|
||||
# find where the object metadata block starts (e.g. `:param p:` or `:returns:`)
|
||||
text = inspect.cleandoc(text)
|
||||
match = re.search("^:", text, flags=re.MULTILINE)
|
||||
if match:
|
||||
desc_chunk = text[: match.start()]
|
||||
meta_chunk = text[match.start() :] # noqa: E203
|
||||
else:
|
||||
desc_chunk = text
|
||||
meta_chunk = ""
|
||||
|
||||
# split description text into short and long description
|
||||
parts = desc_chunk.split("\n\n", 1)
|
||||
|
||||
# ensure short description has no newlines
|
||||
short_description = parts[0].strip().replace("\n", " ") or None
|
||||
|
||||
# ensure long description preserves its structure (e.g. preformatted text)
|
||||
if len(parts) > 1:
|
||||
long_description = parts[1].strip() or None
|
||||
else:
|
||||
long_description = None
|
||||
|
||||
params: Dict[str, DocstringParam] = {}
|
||||
raises: Dict[str, DocstringRaises] = {}
|
||||
returns = None
|
||||
for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE):
|
||||
chunk = match.group(0)
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
args_chunk, desc_chunk = chunk.lstrip(":").split(":", 1)
|
||||
args = args_chunk.split()
|
||||
desc = re.sub(r"\s+", " ", desc_chunk.strip())
|
||||
|
||||
if len(args) > 0:
|
||||
kw = args[0]
|
||||
if len(args) == 2:
|
||||
if kw == "param":
|
||||
params[args[1]] = DocstringParam(
|
||||
name=args[1],
|
||||
description=desc,
|
||||
)
|
||||
elif kw == "raise" or kw == "raises":
|
||||
raises[args[1]] = DocstringRaises(
|
||||
typename=args[1],
|
||||
description=desc,
|
||||
)
|
||||
|
||||
elif len(args) == 1:
|
||||
if kw == "return" or kw == "returns":
|
||||
returns = DocstringReturns(description=desc)
|
||||
|
||||
return Docstring(
|
||||
long_description=long_description,
|
||||
short_description=short_description,
|
||||
params=params,
|
||||
returns=returns,
|
||||
raises=raises,
|
||||
)
|
||||
|
||||
|
||||
def has_default_docstring(typ: SupportsDoc) -> bool:
|
||||
"Check if class has the auto-generated string assigned by @dataclass."
|
||||
|
||||
if not isinstance(typ, type):
|
||||
return False
|
||||
|
||||
if is_dataclass_type(typ):
|
||||
return typ.__doc__ is not None and re.match(f"^{re.escape(typ.__name__)}[(].*[)]$", typ.__doc__) is not None
|
||||
|
||||
if is_type_enum(typ):
|
||||
return typ.__doc__ is not None and typ.__doc__ == "An enumeration."
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def has_docstring(typ: SupportsDoc) -> bool:
|
||||
"Check if class has a documentation string other than the auto-generated string assigned by @dataclass."
|
||||
|
||||
if has_default_docstring(typ):
|
||||
return False
|
||||
|
||||
return bool(typ.__doc__)
|
||||
|
||||
|
||||
def get_docstring(typ: SupportsDoc) -> Optional[str]:
|
||||
if typ.__doc__ is None:
|
||||
return None
|
||||
|
||||
if has_default_docstring(typ):
|
||||
return None
|
||||
|
||||
return typ.__doc__
|
||||
|
||||
|
||||
def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False) -> None:
|
||||
"""
|
||||
Verifies the doc-string of a type.
|
||||
|
||||
:raises TypeError: Raised on a mismatch between doc-string parameters, and function or type signature.
|
||||
"""
|
||||
|
||||
if is_dataclass_type(typ):
|
||||
check_dataclass_docstring(typ, docstring, strict)
|
||||
elif inspect.isfunction(typ):
|
||||
check_function_docstring(typ, docstring, strict)
|
||||
|
||||
|
||||
def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
|
||||
"""
|
||||
Verifies the doc-string of a data-class type.
|
||||
|
||||
:param strict: Whether to check if all data-class members have doc-strings.
|
||||
:raises TypeError: Raised on a mismatch between doc-string parameters and data-class members.
|
||||
"""
|
||||
|
||||
if not is_dataclass_type(typ):
|
||||
raise TypeError("not a data-class type")
|
||||
|
||||
properties = dict(get_class_properties(typ))
|
||||
class_name = typ.__name__
|
||||
|
||||
for name in docstring.params:
|
||||
if name not in properties:
|
||||
raise TypeError(f"doc-string parameter `{name}` is not a member of the data-class `{class_name}`")
|
||||
|
||||
if not strict:
|
||||
return
|
||||
|
||||
for name in properties:
|
||||
if name not in docstring.params:
|
||||
raise TypeError(f"member `{name}` in data-class `{class_name}` is missing its doc-string")
|
||||
|
||||
|
||||
def check_function_docstring(fn: Callable[..., Any], docstring: Docstring, strict: bool = False) -> None:
|
||||
"""
|
||||
Verifies the doc-string of a function or member function.
|
||||
|
||||
:param strict: Whether to check if all function parameters and the return type have doc-strings.
|
||||
:raises TypeError: Raised on a mismatch between doc-string parameters and function signature.
|
||||
"""
|
||||
|
||||
signature = get_signature(fn)
|
||||
func_name = fn.__qualname__
|
||||
|
||||
for name in docstring.params:
|
||||
if name not in signature.parameters:
|
||||
raise TypeError(f"doc-string parameter `{name}` is absent from signature of function `{func_name}`")
|
||||
|
||||
if docstring.returns is not None and signature.return_annotation is inspect.Signature.empty:
|
||||
raise TypeError(f"doc-string has returns description in function `{func_name}` with no return type annotation")
|
||||
|
||||
if not strict:
|
||||
return
|
||||
|
||||
for name, param in signature.parameters.items():
|
||||
# ignore `self` in member function signatures
|
||||
if name == "self" and (
|
||||
param.kind is inspect.Parameter.POSITIONAL_ONLY or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
):
|
||||
continue
|
||||
|
||||
if name not in docstring.params:
|
||||
raise TypeError(f"function parameter `{name}` in `{func_name}` is missing its doc-string")
|
||||
|
||||
if signature.return_annotation is not inspect.Signature.empty and docstring.returns is None:
|
||||
raise TypeError(f"function `{func_name}` has no returns description in its doc-string")
|
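As a quick illustration (not part of the diff, names hypothetical), `parse_type` splits a ReST-style docstring into the `Docstring` components defined above:

```python
from dataclasses import dataclass

from llama_stack.strong_typing.docstring import parse_type


@dataclass
class Point:
    """
    A point in the plane.

    :param x: Horizontal coordinate.
    :param y: Vertical coordinate.
    """

    x: float
    y: float


doc = parse_type(Point)
assert doc.short_description == "A point in the plane."
assert doc.params["y"].description == "Vertical coordinate."
assert doc.returns is None
```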
23
llama_stack/strong_typing/exception.py
Normal file
|
@@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
|
||||
class JsonKeyError(Exception):
|
||||
"Raised when deserialization for a class or union type has failed because a matching member was not found."
|
||||
|
||||
|
||||
class JsonValueError(Exception):
|
||||
"Raised when (de)serialization of data has failed due to invalid value."
|
||||
|
||||
|
||||
class JsonTypeError(Exception):
|
||||
"Raised when deserialization of data has failed due to a type mismatch."
|
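A small, hypothetical sketch (not part of the diff) of how these exceptions surface; it assumes that handing the integer de-serializer a string payload is reported as a `JsonTypeError`:

```python
from llama_stack.strong_typing.exception import JsonTypeError
from llama_stack.strong_typing.serialization import json_to_object

try:
    json_to_object(int, "not-a-number")  # payload type does not match `int`
except JsonTypeError as err:
    print(f"rejected payload: {err}")
```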
1034
llama_stack/strong_typing/inspection.py
Normal file
File diff suppressed because it is too large
40
llama_stack/strong_typing/mapping.py
Normal file
|
@@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import keyword
|
||||
from typing import Optional
|
||||
|
||||
from .auxiliary import Alias
|
||||
from .inspection import get_annotation
|
||||
|
||||
|
||||
def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str:
|
||||
"""
|
||||
Map a Python field identifier to a JSON property name.
|
||||
|
||||
Authors may use an underscore appended at the end of a Python identifier as per PEP 8 if it clashes with a Python
|
||||
keyword: e.g. `in` would become `in_` and `from` would become `from_`. These suffixes are removed when exporting to JSON.
|
||||
|
||||
Authors may supply an explicit alias with the type annotation `Alias`, e.g. `Annotated[MyType, Alias("alias")]`.
|
||||
"""
|
||||
|
||||
if python_type is not None:
|
||||
alias = get_annotation(python_type, Alias)
|
||||
if alias:
|
||||
return alias.name
|
||||
|
||||
if python_id.endswith("_"):
|
||||
id = python_id[:-1]
|
||||
if keyword.iskeyword(id):
|
||||
return id
|
||||
|
||||
return python_id
|
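A short illustration (not part of the diff, identifiers hypothetical) of the two rules described in the docstring above:

```python
from typing import Annotated

from llama_stack.strong_typing.auxiliary import Alias
from llama_stack.strong_typing.mapping import python_field_to_json_property

# a trailing underscore added to avoid a Python keyword is dropped on export
assert python_field_to_json_property("for_") == "for"
# an explicit `Alias` annotation takes precedence over the identifier
assert python_field_to_json_property("schema_url", Annotated[str, Alias("$schema")]) == "$schema"
# ordinary identifiers pass through unchanged
assert python_field_to_json_property("created_at") == "created_at"
```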
182
llama_stack/strong_typing/name.py
Normal file
|
@@ -0,0 +1,182 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import typing
|
||||
from typing import Any, Literal, Optional, Tuple, Union
|
||||
|
||||
from .auxiliary import _auxiliary_types
|
||||
from .inspection import (
|
||||
TypeLike,
|
||||
is_generic_dict,
|
||||
is_generic_list,
|
||||
is_type_optional,
|
||||
is_type_union,
|
||||
unwrap_generic_dict,
|
||||
unwrap_generic_list,
|
||||
unwrap_optional_type,
|
||||
unwrap_union_types,
|
||||
)
|
||||
|
||||
|
||||
class TypeFormatter:
|
||||
"""
|
||||
Type formatter.
|
||||
|
||||
:param use_union_operator: Whether to emit union types as `X | Y` as per PEP 604.
|
||||
"""
|
||||
|
||||
use_union_operator: bool
|
||||
|
||||
def __init__(self, use_union_operator: bool = False) -> None:
|
||||
self.use_union_operator = use_union_operator
|
||||
|
||||
def union_to_str(self, data_type_args: Tuple[TypeLike, ...]) -> str:
|
||||
if self.use_union_operator:
|
||||
return " | ".join(self.python_type_to_str(t) for t in data_type_args)
|
||||
else:
|
||||
if len(data_type_args) == 2 and type(None) in data_type_args:
|
||||
# Optional[T] is represented as Union[T, None]
|
||||
origin_name = "Optional"
|
||||
data_type_args = tuple(t for t in data_type_args if t is not type(None))
|
||||
else:
|
||||
origin_name = "Union"
|
||||
|
||||
args = ", ".join(self.python_type_to_str(t) for t in data_type_args)
|
||||
return f"{origin_name}[{args}]"
|
||||
|
||||
def plain_type_to_str(self, data_type: TypeLike) -> str:
|
||||
"Returns the string representation of a Python type without metadata."
|
||||
|
||||
# return forward references as the annotation string
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
fwd: typing.ForwardRef = data_type
|
||||
return fwd.__forward_arg__
|
||||
elif isinstance(data_type, str):
|
||||
return data_type
|
||||
|
||||
origin = typing.get_origin(data_type)
|
||||
if origin is not None:
|
||||
data_type_args = typing.get_args(data_type)
|
||||
|
||||
if origin is dict: # Dict[K, V]
|
||||
origin_name = "Dict"
|
||||
elif origin is list: # List[T]
|
||||
origin_name = "List"
|
||||
elif origin is set: # Set[T]
|
||||
origin_name = "Set"
|
||||
elif origin is Union:
|
||||
return self.union_to_str(data_type_args)
|
||||
elif origin is Literal:
|
||||
args = ", ".join(repr(arg) for arg in data_type_args)
|
||||
return f"Literal[{args}]"
|
||||
else:
|
||||
origin_name = origin.__name__
|
||||
|
||||
args = ", ".join(self.python_type_to_str(t) for t in data_type_args)
|
||||
return f"{origin_name}[{args}]"
|
||||
|
||||
return data_type.__name__
|
||||
|
||||
def python_type_to_str(self, data_type: TypeLike) -> str:
|
||||
"Returns the string representation of a Python type."
|
||||
|
||||
if data_type is type(None):
|
||||
return "None"
|
||||
|
||||
# use compact name for alias types
|
||||
name = _auxiliary_types.get(data_type)
|
||||
if name is not None:
|
||||
return name
|
||||
|
||||
metadata = getattr(data_type, "__metadata__", None)
|
||||
if metadata is not None:
|
||||
# type is Annotated[T, ...]
|
||||
metatuple: Tuple[Any, ...] = metadata
|
||||
arg = typing.get_args(data_type)[0]
|
||||
|
||||
# check for auxiliary types with user-defined annotations
|
||||
metaset = set(metatuple)
|
||||
for auxiliary_type, auxiliary_name in _auxiliary_types.items():
|
||||
auxiliary_arg = typing.get_args(auxiliary_type)[0]
|
||||
if arg is not auxiliary_arg:
|
||||
continue
|
||||
|
||||
auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None)
|
||||
if auxiliary_metatuple is None:
|
||||
continue
|
||||
|
||||
if metaset.issuperset(auxiliary_metatuple):
|
||||
# type is an auxiliary type with extra annotations
|
||||
auxiliary_args = ", ".join(repr(m) for m in metatuple if m not in auxiliary_metatuple)
|
||||
return f"Annotated[{auxiliary_name}, {auxiliary_args}]"
|
||||
|
||||
# type is an annotated type
|
||||
args = ", ".join(repr(m) for m in metatuple)
|
||||
return f"Annotated[{self.plain_type_to_str(arg)}, {args}]"
|
||||
else:
|
||||
# type is a regular type
|
||||
return self.plain_type_to_str(data_type)
|
||||
|
||||
|
||||
def python_type_to_str(data_type: TypeLike, use_union_operator: bool = False) -> str:
|
||||
"""
|
||||
Returns the string representation of a Python type.
|
||||
|
||||
:param use_union_operator: Whether to emit union types as `X | Y` as per PEP 604.
|
||||
"""
|
||||
|
||||
fmt = TypeFormatter(use_union_operator)
|
||||
return fmt.python_type_to_str(data_type)
|
||||
|
||||
|
||||
def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
|
||||
"""
|
||||
Returns the short name of a Python type.
|
||||
|
||||
:param force: Whether to produce a name for composite types such as generics.
|
||||
"""
|
||||
|
||||
# use compact name for alias types
|
||||
name = _auxiliary_types.get(data_type)
|
||||
if name is not None:
|
||||
return name
|
||||
|
||||
# unwrap annotated types
|
||||
metadata = getattr(data_type, "__metadata__", None)
|
||||
if metadata is not None:
|
||||
# type is Annotated[T, ...]
|
||||
arg = typing.get_args(data_type)[0]
|
||||
return python_type_to_name(arg)
|
||||
|
||||
if force:
|
||||
# generic types
|
||||
if is_type_optional(data_type, strict=True):
|
||||
inner_name = python_type_to_name(unwrap_optional_type(data_type))
|
||||
return f"Optional__{inner_name}"
|
||||
elif is_generic_list(data_type):
|
||||
item_name = python_type_to_name(unwrap_generic_list(data_type))
|
||||
return f"List__{item_name}"
|
||||
elif is_generic_dict(data_type):
|
||||
key_type, value_type = unwrap_generic_dict(data_type)
|
||||
key_name = python_type_to_name(key_type)
|
||||
value_name = python_type_to_name(value_type)
|
||||
return f"Dict__{key_name}__{value_name}"
|
||||
elif is_type_union(data_type):
|
||||
member_types = unwrap_union_types(data_type)
|
||||
member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
|
||||
return f"Union__{member_names}"
|
||||
|
||||
# named system or user-defined type
|
||||
if hasattr(data_type, "__name__") and not typing.get_args(data_type):
|
||||
return data_type.__name__
|
||||
|
||||
raise TypeError(f"cannot assign a simple name to type: {data_type}")
|
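A brief usage sketch (not part of the diff) for the two helpers defined in this file:

```python
from typing import List, Optional, Union

from llama_stack.strong_typing.name import python_type_to_name, python_type_to_str

assert python_type_to_str(Optional[int]) == "Optional[int]"
assert python_type_to_str(Union[int, str], use_union_operator=True) == "int | str"
assert python_type_to_name(List[str], force=True) == "List__str"
assert python_type_to_name(Optional[int], force=True) == "Optional__int"
```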
0
llama_stack/strong_typing/py.typed
Normal file
752
llama_stack/strong_typing/schema.py
Normal file
|
@@ -0,0 +1,752 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import datetime
|
||||
import decimal
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import typing
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import jsonschema
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from . import docstring
|
||||
from .auxiliary import (
|
||||
Alias,
|
||||
IntegerRange,
|
||||
MaxLength,
|
||||
MinLength,
|
||||
Precision,
|
||||
get_auxiliary_format,
|
||||
)
|
||||
from .core import JsonArray, JsonObject, JsonType, Schema, StrictJsonType
|
||||
from .inspection import (
|
||||
TypeLike,
|
||||
enum_value_types,
|
||||
get_annotation,
|
||||
get_class_properties,
|
||||
is_type_enum,
|
||||
is_type_like,
|
||||
is_type_optional,
|
||||
unwrap_optional_type,
|
||||
)
|
||||
from .name import python_type_to_name
|
||||
from .serialization import object_to_json
|
||||
|
||||
# determines the maximum number of distinct enum members up to which a Dict[EnumType, Any] is converted into a JSON
|
||||
# schema with explicitly listed properties (rather than employing a pattern constraint on property names)
|
||||
OBJECT_ENUM_EXPANSION_LIMIT = 4
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]:
|
||||
docstr = docstring.parse_type(data_type)
|
||||
|
||||
# check if class has a doc-string other than the auto-generated string assigned by @dataclass
|
||||
if docstring.has_default_docstring(data_type):
|
||||
return None, None
|
||||
|
||||
return docstr.short_description, docstr.long_description
|
||||
|
||||
|
||||
def get_class_property_docstrings(
|
||||
data_type: type, transform_fun: Optional[Callable[[type, str, str], str]] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Extracts the documentation strings associated with the properties of a composite type.
|
||||
|
||||
:param data_type: The object whose properties to iterate over.
|
||||
:param transform_fun: An optional function that maps a property documentation string to a custom tailored string.
|
||||
:returns: A dictionary mapping property names to descriptions.
|
||||
"""
|
||||
|
||||
result = {}
|
||||
for base in inspect.getmro(data_type):
|
||||
docstr = docstring.parse_type(base)
|
||||
for param in docstr.params.values():
|
||||
if param.name in result:
|
||||
continue
|
||||
|
||||
if transform_fun:
|
||||
description = transform_fun(data_type, param.name, param.description)
|
||||
else:
|
||||
description = param.description
|
||||
|
||||
result[param.name] = description
|
||||
return result
|
||||
|
||||
|
||||
def docstring_to_schema(data_type: type) -> Schema:
|
||||
short_description, long_description = get_class_docstrings(data_type)
|
||||
schema: Schema = {}
|
||||
|
||||
description = "\n".join(filter(None, [short_description, long_description]))
|
||||
if description:
|
||||
schema["description"] = description
|
||||
return schema
|
||||
|
||||
|
||||
def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
|
||||
"Extracts the name of a possibly forward-referenced type."
|
||||
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
forward_type: typing.ForwardRef = data_type
|
||||
return forward_type.__forward_arg__
|
||||
elif isinstance(data_type, str):
|
||||
return data_type
|
||||
else:
|
||||
return data_type.__name__
|
||||
|
||||
|
||||
def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, type]:
|
||||
"Creates a type from a forward reference."
|
||||
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
forward_type: typing.ForwardRef = data_type
|
||||
true_type = eval(forward_type.__forward_code__)
|
||||
return forward_type.__forward_arg__, true_type
|
||||
elif isinstance(data_type, str):
|
||||
true_type = eval(data_type)
|
||||
return data_type, true_type
|
||||
else:
|
||||
return data_type.__name__, data_type
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeCatalogEntry:
|
||||
schema: Optional[Schema]
|
||||
identifier: str
|
||||
examples: Optional[JsonType] = None
|
||||
|
||||
|
||||
class TypeCatalog:
|
||||
"Maintains an association of well-known Python types to their JSON schema."
|
||||
|
||||
_by_type: Dict[TypeLike, TypeCatalogEntry]
|
||||
_by_name: Dict[str, TypeCatalogEntry]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._by_type = {}
|
||||
self._by_name = {}
|
||||
|
||||
def __contains__(self, data_type: TypeLike) -> bool:
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
fwd: typing.ForwardRef = data_type
|
||||
name = fwd.__forward_arg__
|
||||
return name in self._by_name
|
||||
else:
|
||||
return data_type in self._by_type
|
||||
|
||||
def add(
|
||||
self,
|
||||
data_type: TypeLike,
|
||||
schema: Optional[Schema],
|
||||
identifier: str,
|
||||
examples: Optional[List[JsonType]] = None,
|
||||
) -> None:
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
raise TypeError("forward references cannot be used to register a type")
|
||||
|
||||
if data_type in self._by_type:
|
||||
raise ValueError(f"type {data_type} is already registered in the catalog")
|
||||
|
||||
entry = TypeCatalogEntry(schema, identifier, examples)
|
||||
self._by_type[data_type] = entry
|
||||
self._by_name[identifier] = entry
|
||||
|
||||
def get(self, data_type: TypeLike) -> TypeCatalogEntry:
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
fwd: typing.ForwardRef = data_type
|
||||
name = fwd.__forward_arg__
|
||||
return self._by_name[name]
|
||||
else:
|
||||
return self._by_type[data_type]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SchemaOptions:
|
||||
definitions_path: str = "#/definitions/"
|
||||
use_descriptions: bool = True
|
||||
use_examples: bool = True
|
||||
property_description_fun: Optional[Callable[[type, str, str], str]] = None
|
||||
|
||||
|
||||
class JsonSchemaGenerator:
|
||||
"Creates a JSON schema with user-defined type definitions."
|
||||
|
||||
type_catalog: ClassVar[TypeCatalog] = TypeCatalog()
|
||||
types_used: Dict[str, TypeLike]
|
||||
options: SchemaOptions
|
||||
|
||||
def __init__(self, options: Optional[SchemaOptions] = None):
|
||||
if options is None:
|
||||
self.options = SchemaOptions()
|
||||
else:
|
||||
self.options = options
|
||||
self.types_used = {}
|
||||
|
||||
@functools.singledispatchmethod
|
||||
def _metadata_to_schema(self, arg: object) -> Schema:
|
||||
# unrecognized annotation
|
||||
return {}
|
||||
|
||||
@_metadata_to_schema.register
|
||||
def _(self, arg: IntegerRange) -> Schema:
|
||||
return {"minimum": arg.minimum, "maximum": arg.maximum}
|
||||
|
||||
@_metadata_to_schema.register
|
||||
def _(self, arg: Precision) -> Schema:
|
||||
return {
|
||||
"multipleOf": 10 ** (-arg.decimal_digits),
|
||||
"exclusiveMinimum": -(10**arg.integer_digits),
|
||||
"exclusiveMaximum": (10**arg.integer_digits),
|
||||
}
|
||||
|
||||
@_metadata_to_schema.register
|
||||
def _(self, arg: MinLength) -> Schema:
|
||||
return {"minLength": arg.value}
|
||||
|
||||
@_metadata_to_schema.register
|
||||
def _(self, arg: MaxLength) -> Schema:
|
||||
return {"maxLength": arg.value}
|
||||
|
||||
def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema:
|
||||
if metadata:
|
||||
for m in metadata:
|
||||
type_schema.update(self._metadata_to_schema(m))
|
||||
return type_schema
|
||||
|
||||
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]:
|
||||
"""
|
||||
Returns the JSON schema associated with a simple, unrestricted type.
|
||||
|
||||
:returns: The schema for a simple type, or `None`.
|
||||
"""
|
||||
|
||||
if typ is type(None):
|
||||
return {"type": "null"}
|
||||
elif typ is bool:
|
||||
return {"type": "boolean"}
|
||||
elif typ is int:
|
||||
return {"type": "integer"}
|
||||
elif typ is float:
|
||||
return {"type": "number"}
|
||||
elif typ is str:
|
||||
if json_schema_extra and "contentEncoding" in json_schema_extra:
|
||||
return {
|
||||
"type": "string",
|
||||
"contentEncoding": json_schema_extra["contentEncoding"],
|
||||
}
|
||||
return {"type": "string"}
|
||||
elif typ is bytes:
|
||||
return {"type": "string", "contentEncoding": "base64"}
|
||||
elif typ is datetime.datetime:
|
||||
# 2018-11-13T20:20:39+00:00
|
||||
return {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
}
|
||||
elif typ is datetime.date:
|
||||
# 2018-11-13
|
||||
return {"type": "string", "format": "date"}
|
||||
elif typ is datetime.time:
|
||||
# 20:20:39+00:00
|
||||
return {"type": "string", "format": "time"}
|
||||
elif typ is decimal.Decimal:
|
||||
return {"type": "number"}
|
||||
elif typ is uuid.UUID:
|
||||
# f81d4fae-7dec-11d0-a765-00a0c91e6bf6
|
||||
return {"type": "string", "format": "uuid"}
|
||||
elif typ is Any:
|
||||
return {
|
||||
"oneOf": [
|
||||
{"type": "null"},
|
||||
{"type": "boolean"},
|
||||
{"type": "number"},
|
||||
{"type": "string"},
|
||||
{"type": "array"},
|
||||
{"type": "object"},
|
||||
]
|
||||
}
|
||||
elif typ is JsonObject:
|
||||
return {"type": "object"}
|
||||
elif typ is JsonArray:
|
||||
return {"type": "array"}
|
||||
else:
|
||||
# not a simple type
|
||||
return None
|
||||
|
||||
def type_to_schema(
|
||||
self,
|
||||
data_type: TypeLike,
|
||||
force_expand: bool = False,
|
||||
json_schema_extra: Optional[dict] = None,
|
||||
) -> Schema:
|
||||
"""
|
||||
Returns the JSON schema associated with a type.
|
||||
|
||||
:param data_type: The Python type whose JSON schema to return.
|
||||
:param force_expand: Forces a JSON schema to be returned even if the type is registered in the catalog of known types.
|
||||
:returns: The JSON schema associated with the type.
|
||||
"""
|
||||
|
||||
# short-circuit for common simple types
|
||||
schema = self._simple_type_to_schema(data_type, json_schema_extra)
|
||||
if schema is not None:
|
||||
return schema
|
||||
|
||||
# types registered in the type catalog of well-known types
|
||||
type_catalog = JsonSchemaGenerator.type_catalog
|
||||
if not force_expand and data_type in type_catalog:
|
||||
# user-defined type
|
||||
identifier = type_catalog.get(data_type).identifier
|
||||
self.types_used.setdefault(identifier, data_type)
|
||||
return {"$ref": f"{self.options.definitions_path}{identifier}"}
|
||||
|
||||
# unwrap annotated types
|
||||
metadata = getattr(data_type, "__metadata__", None)
|
||||
if metadata is not None:
|
||||
# type is Annotated[T, ...]
|
||||
typ = typing.get_args(data_type)[0]
|
||||
schema = self._simple_type_to_schema(typ)
|
||||
if schema is not None:
|
||||
# recognize well-known auxiliary types
|
||||
fmt = get_auxiliary_format(data_type)
|
||||
if fmt is not None:
|
||||
schema.update({"format": fmt})
|
||||
return schema
|
||||
else:
|
||||
return self._with_metadata(schema, metadata)
|
||||
|
||||
else:
|
||||
# type is a regular type
|
||||
typ = data_type
|
||||
|
||||
if isinstance(typ, typing.ForwardRef) or isinstance(typ, str):
|
||||
if force_expand:
|
||||
identifier, true_type = type_from_ref(typ)
|
||||
return self.type_to_schema(true_type, force_expand=True)
|
||||
else:
|
||||
try:
|
||||
identifier, true_type = type_from_ref(typ)
|
||||
self.types_used[identifier] = true_type
|
||||
except NameError:
|
||||
identifier = id_from_ref(typ)
|
||||
|
||||
return {"$ref": f"{self.options.definitions_path}{identifier}"}
|
||||
|
||||
if is_type_enum(typ):
|
||||
enum_type: Type[enum.Enum] = typ
|
||||
value_types = enum_value_types(enum_type)
|
||||
if len(value_types) != 1:
|
||||
raise ValueError(
|
||||
f"enumerations must have a consistent member value type but several types found: {value_types}"
|
||||
)
|
||||
enum_value_type = value_types.pop()
|
||||
|
||||
enum_schema: Schema
|
||||
if enum_value_type is bool or enum_value_type is int or enum_value_type is float or enum_value_type is str:
|
||||
if enum_value_type is bool:
|
||||
enum_schema_type = "boolean"
|
||||
elif enum_value_type is int:
|
||||
enum_schema_type = "integer"
|
||||
elif enum_value_type is float:
|
||||
enum_schema_type = "number"
|
||||
elif enum_value_type is str:
|
||||
enum_schema_type = "string"
|
||||
|
||||
enum_schema = {
|
||||
"type": enum_schema_type,
|
||||
"enum": [object_to_json(e.value) for e in enum_type],
|
||||
}
|
||||
if self.options.use_descriptions:
|
||||
enum_schema.update(docstring_to_schema(typ))
|
||||
return enum_schema
|
||||
else:
|
||||
enum_schema = self.type_to_schema(enum_value_type)
|
||||
if self.options.use_descriptions:
|
||||
enum_schema.update(docstring_to_schema(typ))
|
||||
return enum_schema
|
||||
|
||||
origin_type = typing.get_origin(typ)
|
||||
if origin_type is list:
|
||||
(list_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"type": "array", "items": self.type_to_schema(list_type)}
|
||||
elif origin_type is dict:
|
||||
key_type, value_type = typing.get_args(typ)
|
||||
if not (key_type is str or key_type is int or is_type_enum(key_type)):
|
||||
raise ValueError("`dict` with key type not coercible to `str` is not supported")
|
||||
|
||||
dict_schema: Schema
|
||||
value_schema = self.type_to_schema(value_type)
|
||||
if is_type_enum(key_type):
|
||||
enum_values = [str(e.value) for e in key_type]
|
||||
if len(enum_values) > OBJECT_ENUM_EXPANSION_LIMIT:
|
||||
dict_schema = {
|
||||
"propertyNames": {"pattern": "^(" + "|".join(enum_values) + ")$"},
|
||||
"additionalProperties": value_schema,
|
||||
}
|
||||
else:
|
||||
dict_schema = {
|
||||
"properties": {value: value_schema for value in enum_values},
|
||||
"additionalProperties": False,
|
||||
}
|
||||
else:
|
||||
dict_schema = {"additionalProperties": value_schema}
|
||||
|
||||
schema = {"type": "object"}
|
||||
schema.update(dict_schema)
|
||||
return schema
|
||||
elif origin_type is set:
|
||||
(set_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {
|
||||
"type": "array",
|
||||
"items": self.type_to_schema(set_type),
|
||||
"uniqueItems": True,
|
||||
}
|
||||
elif origin_type is tuple:
|
||||
args = typing.get_args(typ)
|
||||
return {
|
||||
"type": "array",
|
||||
"minItems": len(args),
|
||||
"maxItems": len(args),
|
||||
"prefixItems": [self.type_to_schema(member_type) for member_type in args],
|
||||
}
|
||||
elif origin_type is Union:
|
||||
discriminator = None
|
||||
if typing.get_origin(data_type) is Annotated:
|
||||
discriminator = typing.get_args(data_type)[1].discriminator
|
||||
ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
|
||||
if discriminator:
|
||||
# for each union type, we need to read the value of the discriminator
|
||||
mapping = {}
|
||||
for union_type in typing.get_args(typ):
|
||||
props = self.type_to_schema(union_type, force_expand=True)["properties"]
|
||||
mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"]
|
||||
|
||||
ret["discriminator"] = {
|
||||
"propertyName": discriminator,
|
||||
"mapping": mapping,
|
||||
}
|
||||
return ret
|
||||
elif origin_type is Literal:
|
||||
(literal_value,) = typing.get_args(typ) # unpack value of literal type
|
||||
schema = self.type_to_schema(type(literal_value))
|
||||
schema["const"] = literal_value
|
||||
return schema
|
||||
elif origin_type is type:
|
||||
(concrete_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"const": self.type_to_schema(concrete_type, force_expand=True)}
|
||||
|
||||
# dictionary of class attributes
|
||||
members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
|
||||
|
||||
property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun)
|
||||
properties: Dict[str, Schema] = {}
|
||||
required: List[str] = []
|
||||
for property_name, property_type in get_class_properties(typ):
|
||||
# rename property if an alias name is specified
|
||||
alias = get_annotation(property_type, Alias)
|
||||
if alias:
|
||||
output_name = alias.name
|
||||
else:
|
||||
output_name = property_name
|
||||
|
||||
defaults = {}
|
||||
json_schema_extra = None
|
||||
if "model_fields" in members:
|
||||
f = members["model_fields"]
|
||||
defaults = {k: finfo.default for k, finfo in f.items()}
|
||||
json_schema_extra = f.get(output_name, None).json_schema_extra
|
||||
|
||||
if is_type_optional(property_type):
|
||||
optional_type: type = unwrap_optional_type(property_type)
|
||||
property_def = self.type_to_schema(optional_type, json_schema_extra=json_schema_extra)
|
||||
else:
|
||||
property_def = self.type_to_schema(property_type, json_schema_extra=json_schema_extra)
|
||||
required.append(output_name)
|
||||
|
||||
# check if attribute has a default value initializer
|
||||
if defaults.get(property_name) is not None:
|
||||
def_value = defaults[property_name]
|
||||
# check if value can be directly represented in JSON
|
||||
if isinstance(
|
||||
def_value,
|
||||
(
|
||||
bool,
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
enum.Enum,
|
||||
datetime.datetime,
|
||||
datetime.date,
|
||||
datetime.time,
|
||||
),
|
||||
):
|
||||
property_def["default"] = object_to_json(def_value)
|
||||
|
||||
# add property docstring if available
|
||||
property_doc = property_docstrings.get(property_name)
|
||||
if property_doc:
|
||||
# print(output_name, property_doc)
|
||||
property_def.pop("title", None)
|
||||
property_def["description"] = property_doc
|
||||
|
||||
properties[output_name] = property_def
|
||||
|
||||
schema = {"type": "object"}
|
||||
if len(properties) > 0:
|
||||
schema["properties"] = typing.cast(JsonType, properties)
|
||||
schema["additionalProperties"] = False
|
||||
if len(required) > 0:
|
||||
schema["required"] = typing.cast(JsonType, required)
|
||||
if self.options.use_descriptions:
|
||||
schema.update(docstring_to_schema(typ))
|
||||
return schema
|
||||
|
||||
def _type_to_schema_with_lookup(self, data_type: TypeLike) -> Schema:
|
||||
"""
|
||||
Returns the JSON schema associated with a type that may be registered in the catalog of known types.
|
||||
|
||||
:param data_type: The type whose JSON schema we seek.
|
||||
:returns: The JSON schema associated with the type.
|
||||
"""
|
||||
|
||||
entry = JsonSchemaGenerator.type_catalog.get(data_type)
|
||||
if entry.schema is None:
|
||||
type_schema = self.type_to_schema(data_type, force_expand=True)
|
||||
else:
|
||||
type_schema = deepcopy(entry.schema)
|
||||
|
||||
# add descriptive text (if present)
|
||||
if self.options.use_descriptions:
|
||||
if isinstance(data_type, type) and not isinstance(data_type, typing.ForwardRef):
|
||||
type_schema.update(docstring_to_schema(data_type))
|
||||
|
||||
# add example (if present)
|
||||
if self.options.use_examples and entry.examples:
|
||||
type_schema["examples"] = entry.examples
|
||||
|
||||
return type_schema
|
||||
|
||||
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]:
|
||||
"""
|
||||
Returns the JSON schema associated with a type and any nested types.
|
||||
|
||||
:param data_type: The type whose JSON schema to return.
|
||||
:param force_expand: True if a full JSON schema is to be returned even for well-known types; false if a schema
|
||||
reference is to be used for well-known types.
|
||||
:returns: A tuple of the JSON schema, and a mapping between nested type names and their corresponding schema.
|
||||
"""
|
||||
|
||||
if not is_type_like(data_type):
|
||||
raise TypeError(f"expected a type-like object but got: {data_type}")
|
||||
|
||||
self.types_used = {}
|
||||
try:
|
||||
type_schema = self.type_to_schema(data_type, force_expand=force_expand)
|
||||
|
||||
types_defined: Dict[str, Schema] = {}
|
||||
while len(self.types_used) > len(types_defined):
|
||||
# make a snapshot copy; original collection is going to be modified
|
||||
types_undefined = {
|
||||
sub_name: sub_type
|
||||
for sub_name, sub_type in self.types_used.items()
|
||||
if sub_name not in types_defined
|
||||
}
|
||||
|
||||
# expand undefined types, which may lead to additional types to be defined
|
||||
for sub_name, sub_type in types_undefined.items():
|
||||
types_defined[sub_name] = self._type_to_schema_with_lookup(sub_type)
|
||||
|
||||
type_definitions = dict(sorted(types_defined.items()))
|
||||
finally:
|
||||
self.types_used = {}
|
||||
|
||||
return type_schema, type_definitions
|
||||
|
||||
|
||||
class Validator(enum.Enum):
|
||||
"Defines constants for JSON schema standards."
|
||||
|
||||
Draft7 = jsonschema.Draft7Validator
|
||||
Draft201909 = jsonschema.Draft201909Validator
|
||||
Draft202012 = jsonschema.Draft202012Validator
|
||||
Latest = jsonschema.Draft202012Validator
|
||||
|
||||
|
||||
def classdef_to_schema(
|
||||
data_type: TypeLike,
|
||||
options: Optional[SchemaOptions] = None,
|
||||
validator: Validator = Validator.Latest,
|
||||
) -> Schema:
|
||||
"""
|
||||
Returns the JSON schema corresponding to the given type.
|
||||
|
||||
:param data_type: The Python type used to generate the JSON schema
|
||||
:returns: A JSON object that you can serialize to a JSON string with `json.dump` or `json.dumps`.
|
||||
:raises TypeError: Indicates that the generated JSON schema does not validate against the desired meta-schema.
|
||||
"""
|
||||
|
||||
# short-circuit with an error message when passing invalid data
|
||||
if not is_type_like(data_type):
|
||||
raise TypeError(f"expected a type-like object but got: {data_type}")
|
||||
|
||||
generator = JsonSchemaGenerator(options)
|
||||
type_schema, type_definitions = generator.classdef_to_schema(data_type)
|
||||
|
||||
class_schema: Schema = {}
|
||||
if type_definitions:
|
||||
class_schema["definitions"] = typing.cast(JsonType, type_definitions)
|
||||
class_schema.update(type_schema)
|
||||
|
||||
validator_id = validator.value.META_SCHEMA["$id"]
|
||||
try:
|
||||
validator.value.check_schema(class_schema)
|
||||
except jsonschema.exceptions.SchemaError:
|
||||
raise TypeError(f"schema does not validate against meta-schema <{validator_id}>")
|
||||
|
||||
schema = {"$schema": validator_id}
|
||||
schema.update(class_schema)
|
||||
return schema
|
||||
|
||||
|
||||
def validate_object(data_type: TypeLike, json_dict: JsonType) -> None:
|
||||
"""
|
||||
Validates if the JSON dictionary object conforms to the expected type.
|
||||
|
||||
:param data_type: The type to match against.
|
||||
:param json_dict: A JSON object obtained with `json.load` or `json.loads`.
|
||||
:raises jsonschema.exceptions.ValidationError: Indicates that the JSON object cannot represent the type.
|
||||
"""
|
||||
|
||||
schema_dict = classdef_to_schema(data_type)
|
||||
jsonschema.validate(json_dict, schema_dict, format_checker=jsonschema.FormatChecker())
|
||||
|
||||
|
||||
def print_schema(data_type: type) -> None:
|
||||
"""Pretty-prints the JSON schema corresponding to the type."""
|
||||
|
||||
s = classdef_to_schema(data_type)
|
||||
print(json.dumps(s, indent=4))
|
||||
|
||||
|
||||
def get_schema_identifier(data_type: type) -> Optional[str]:
|
||||
if data_type in JsonSchemaGenerator.type_catalog:
|
||||
return JsonSchemaGenerator.type_catalog.get(data_type).identifier
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def register_schema(
|
||||
data_type: T,
|
||||
schema: Optional[Schema] = None,
|
||||
name: Optional[str] = None,
|
||||
examples: Optional[List[JsonType]] = None,
|
||||
) -> T:
|
||||
"""
|
||||
Associates a type with a JSON schema definition.
|
||||
|
||||
:param data_type: The type to associate with a JSON schema.
|
||||
:param schema: The schema to associate the type with. Derived automatically if omitted.
|
||||
:param name: The name used for looking up the type. Determined automatically if omitted.
|
||||
:returns: The input type.
|
||||
"""
|
||||
|
||||
JsonSchemaGenerator.type_catalog.add(
|
||||
data_type,
|
||||
schema,
|
||||
name if name is not None else python_type_to_name(data_type),
|
||||
examples,
|
||||
)
|
||||
return data_type
|
||||
|
||||
|
||||
@overload
|
||||
def json_schema_type(cls: Type[T], /) -> Type[T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ...
|
||||
|
||||
|
||||
def json_schema_type(
|
||||
cls: Optional[Type[T]] = None,
|
||||
*,
|
||||
schema: Optional[Schema] = None,
|
||||
examples: Optional[List[JsonType]] = None,
|
||||
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
|
||||
"""Decorator to add user-defined schema definition to a class."""
|
||||
|
||||
def wrap(cls: Type[T]) -> Type[T]:
|
||||
return register_schema(cls, schema, examples=examples)
|
||||
|
||||
# see if decorator is used as @json_schema_type or @json_schema_type()
|
||||
if cls is None:
|
||||
# called with parentheses
|
||||
return wrap
|
||||
else:
|
||||
# called as @json_schema_type without parentheses
|
||||
return wrap(cls)
|
||||
|
||||
|
||||
register_schema(JsonObject, name="JsonObject")
|
||||
register_schema(JsonArray, name="JsonArray")
|
||||
|
||||
register_schema(
|
||||
JsonType,
|
||||
name="JsonType",
|
||||
examples=[
|
||||
{
|
||||
"property1": None,
|
||||
"property2": True,
|
||||
"property3": 64,
|
||||
"property4": "string",
|
||||
"property5": ["item"],
|
||||
"property6": {"key": "value"},
|
||||
}
|
||||
],
|
||||
)
|
||||
register_schema(
|
||||
StrictJsonType,
|
||||
name="StrictJsonType",
|
||||
examples=[
|
||||
{
|
||||
"property1": True,
|
||||
"property2": 64,
|
||||
"property3": "string",
|
||||
"property4": ["item"],
|
||||
"property5": {"key": "value"},
|
||||
}
|
||||
],
|
||||
)
|
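To make the generator concrete, here is a hypothetical sketch (not part of the diff) that produces and validates a schema for a small data class:

```python
from dataclasses import dataclass
from typing import List, Optional

from llama_stack.strong_typing.schema import classdef_to_schema, validate_object


@dataclass
class Server:
    """
    A server entry.

    :param host: Host name of the server.
    :param port: TCP port to connect to.
    :param tags: Optional labels attached to the server.
    """

    host: str
    port: int
    tags: Optional[List[str]] = None


schema = classdef_to_schema(Server)
assert schema["type"] == "object"
assert sorted(schema["required"]) == ["host", "port"]  # `tags` is optional

# raises jsonschema.exceptions.ValidationError if the object does not conform
validate_object(Server, {"host": "db1.example.com", "port": 5432})
```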
97
llama_stack/strong_typing/serialization.py
Normal file
|
@@ -0,0 +1,97 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Type-safe data interchange for Python data classes.
|
||||
|
||||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any, Optional, TextIO, TypeVar
|
||||
|
||||
from .core import JsonType
|
||||
from .deserializer import create_deserializer
|
||||
from .inspection import TypeLike
|
||||
from .serializer import create_serializer
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def object_to_json(obj: Any) -> JsonType:
|
||||
"""
|
||||
Converts a Python object to a representation that can be exported to JSON.
|
||||
|
||||
* Fundamental types (e.g. numeric types) are written as is.
|
||||
* Date and time types are serialized in the ISO 8601 format with time zone.
|
||||
* A byte array is written as a string with Base64 encoding.
|
||||
* UUIDs are written as a UUID string.
|
||||
* Enumerations are written as their value.
|
||||
* Containers (e.g. `list`, `dict`, `set`, `tuple`) are exported recursively.
|
||||
* Objects with properties (including data class types) are converted to dictionaries of key-value pairs.
|
||||
"""
|
||||
|
||||
typ: type = type(obj)
|
||||
generator = create_serializer(typ)
|
||||
return generator.generate(obj)
|
||||
|
||||
|
||||
def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object:
|
||||
"""
|
||||
Creates an object from a representation that has been de-serialized from JSON.
|
||||
|
||||
When de-serializing a JSON object into a Python object, the following transformations are applied:
|
||||
|
||||
* Fundamental types are parsed as `bool`, `int`, `float` or `str`.
|
||||
* Date and time types are parsed from the ISO 8601 format with time zone into the corresponding Python type
|
||||
`datetime`, `date` or `time`.
|
||||
* A byte array is read from a string with Base64 encoding into a `bytes` instance.
|
||||
* UUIDs are extracted from a UUID string into a `uuid.UUID` instance.
|
||||
* Enumerations are instantiated with a lookup on enumeration value.
|
||||
* Containers (e.g. `list`, `dict`, `set`, `tuple`) are parsed recursively.
|
||||
* Complex objects with properties (including data class types) are populated from dictionaries of key-value pairs
|
||||
using reflection (enumerating type annotations).
|
||||
|
||||
:raises TypeError: A de-serializing engine cannot be constructed for the input type.
|
||||
:raises JsonKeyError: Deserialization for a class or union type has failed because a matching member was not found.
|
||||
:raises JsonTypeError: Deserialization of data has failed due to a type mismatch.
|
||||
"""
|
||||
|
||||
# use caller context for evaluating types if no context is supplied
|
||||
if context is None:
|
||||
this_frame = inspect.currentframe()
|
||||
if this_frame is not None:
|
||||
caller_frame = this_frame.f_back
|
||||
del this_frame
|
||||
|
||||
if caller_frame is not None:
|
||||
try:
|
||||
context = sys.modules[caller_frame.f_globals["__name__"]]
|
||||
finally:
|
||||
del caller_frame
|
||||
|
||||
parser = create_deserializer(typ, context)
|
||||
return parser.parse(data)
|
||||
|
||||
|
||||
def json_dump_string(json_object: JsonType) -> str:
|
||||
"Dump an object as a JSON string with a compact representation."
|
||||
|
||||
return json.dumps(json_object, ensure_ascii=False, check_circular=False, separators=(",", ":"))
|
||||
|
||||
|
||||
def json_dump(json_object: JsonType, file: TextIO) -> None:
|
||||
json.dump(
|
||||
json_object,
|
||||
file,
|
||||
ensure_ascii=False,
|
||||
check_circular=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
file.write("\n")
|
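A minimal round-trip sketch (not part of the diff, `Event` is hypothetical) using the two entry points defined above:

```python
from dataclasses import dataclass

from llama_stack.strong_typing.serialization import json_to_object, object_to_json


@dataclass
class Event:
    name: str
    attempts: int


payload = object_to_json(Event(name="deploy", attempts=3))
# expected shape per the rules above: {"name": "deploy", "attempts": 3}
restored = json_to_object(Event, payload)
assert restored == Event(name="deploy", attempts=3)
```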
497
llama_stack/strong_typing/serializer.py
Normal file
|
@@ -0,0 +1,497 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Type-safe data interchange for Python data classes.

:see: https://github.com/hunyadi/strong_typing
"""

import abc
import base64
import datetime
import enum
import functools
import inspect
import ipaddress
import sys
import typing
import uuid
from types import FunctionType, MethodType, ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Literal,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from .core import JsonType
from .exception import JsonTypeError, JsonValueError
from .inspection import (
    TypeLike,
    enum_value_types,
    evaluate_type,
    get_class_properties,
    get_resolved_hints,
    is_dataclass_type,
    is_named_tuple_type,
    is_reserved_property,
    is_type_annotated,
    is_type_enum,
    unwrap_annotated_type,
)
from .mapping import python_field_to_json_property

T = TypeVar("T")


class Serializer(abc.ABC, Generic[T]):
    @abc.abstractmethod
    def generate(self, data: T) -> JsonType: ...


class NoneSerializer(Serializer[None]):
    def generate(self, data: None) -> None:
        # can be directly represented in JSON
        return None


class BoolSerializer(Serializer[bool]):
    def generate(self, data: bool) -> bool:
        # can be directly represented in JSON
        return data


class IntSerializer(Serializer[int]):
    def generate(self, data: int) -> int:
        # can be directly represented in JSON
        return data


class FloatSerializer(Serializer[float]):
    def generate(self, data: float) -> float:
        # can be directly represented in JSON
        return data


class StringSerializer(Serializer[str]):
    def generate(self, data: str) -> str:
        # can be directly represented in JSON
        return data


class BytesSerializer(Serializer[bytes]):
    def generate(self, data: bytes) -> str:
        return base64.b64encode(data).decode("ascii")


class DateTimeSerializer(Serializer[datetime.datetime]):
    def generate(self, obj: datetime.datetime) -> str:
        if obj.tzinfo is None:
            raise JsonValueError(f"timestamp lacks explicit time zone designator: {obj}")
        fmt = obj.isoformat()
        if fmt.endswith("+00:00"):
            fmt = f"{fmt[:-6]}Z"  # Python's isoformat() does not support military time zones like "Zulu" for UTC
        return fmt


class DateSerializer(Serializer[datetime.date]):
    def generate(self, obj: datetime.date) -> str:
        return obj.isoformat()


class TimeSerializer(Serializer[datetime.time]):
    def generate(self, obj: datetime.time) -> str:
        return obj.isoformat()


class UUIDSerializer(Serializer[uuid.UUID]):
    def generate(self, obj: uuid.UUID) -> str:
        return str(obj)


class IPv4Serializer(Serializer[ipaddress.IPv4Address]):
    def generate(self, obj: ipaddress.IPv4Address) -> str:
        return str(obj)


class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
    def generate(self, obj: ipaddress.IPv6Address) -> str:
        return str(obj)


class EnumSerializer(Serializer[enum.Enum]):
    def generate(self, obj: enum.Enum) -> Union[int, str]:
        return obj.value


class UntypedListSerializer(Serializer[list]):
    def generate(self, obj: list) -> List[JsonType]:
        return [object_to_json(item) for item in obj]


class UntypedDictSerializer(Serializer[dict]):
    def generate(self, obj: dict) -> Dict[str, JsonType]:
        if obj and isinstance(next(iter(obj.keys())), enum.Enum):
            iterator = ((key.value, object_to_json(value)) for key, value in obj.items())
        else:
            iterator = ((str(key), object_to_json(value)) for key, value in obj.items())
        return dict(iterator)


class UntypedSetSerializer(Serializer[set]):
    def generate(self, obj: set) -> List[JsonType]:
        return [object_to_json(item) for item in obj]


class UntypedTupleSerializer(Serializer[tuple]):
    def generate(self, obj: tuple) -> List[JsonType]:
        return [object_to_json(item) for item in obj]


class TypedCollectionSerializer(Serializer, Generic[T]):
    generator: Serializer[T]

    def __init__(self, item_type: Type[T], context: Optional[ModuleType]) -> None:
        self.generator = _get_serializer(item_type, context)


class TypedListSerializer(TypedCollectionSerializer[T]):
    def generate(self, obj: List[T]) -> List[JsonType]:
        return [self.generator.generate(item) for item in obj]


class TypedStringDictSerializer(TypedCollectionSerializer[T]):
    def __init__(self, value_type: Type[T], context: Optional[ModuleType]) -> None:
        super().__init__(value_type, context)

    def generate(self, obj: Dict[str, T]) -> Dict[str, JsonType]:
        return {key: self.generator.generate(value) for key, value in obj.items()}


class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
    def __init__(
        self,
        key_type: Type[enum.Enum],
        value_type: Type[T],
        context: Optional[ModuleType],
    ) -> None:
        super().__init__(value_type, context)

        value_types = enum_value_types(key_type)
        if len(value_types) != 1:
            raise JsonTypeError(
                f"invalid key type, enumerations must have a consistent member value type but several types found: {value_types}"
            )

        value_type = value_types.pop()
        if value_type is not str:
            raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values")

    def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]:
        return {key.value: self.generator.generate(value) for key, value in obj.items()}


class TypedSetSerializer(TypedCollectionSerializer[T]):
    def generate(self, obj: Set[T]) -> JsonType:
        return [self.generator.generate(item) for item in obj]


class TypedTupleSerializer(Serializer[tuple]):
    item_generators: Tuple[Serializer, ...]

    def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None:
        self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)

    def generate(self, obj: tuple) -> List[JsonType]:
        return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)]


class CustomSerializer(Serializer):
    converter: Callable[[object], JsonType]

    def __init__(self, converter: Callable[[object], JsonType]) -> None:
        self.converter = converter

    def generate(self, obj: object) -> JsonType:
        return self.converter(obj)


class FieldSerializer(Generic[T]):
    """
    Serializes a Python object field into a JSON property.

    :param field_name: The name of the field in a Python class to read data from.
    :param property_name: The name of the JSON property to write to a JSON `object`.
    :param generator: A compatible serializer that can handle the field's type.
    """

    field_name: str
    property_name: str
    generator: Serializer

    def __init__(self, field_name: str, property_name: str, generator: Serializer[T]) -> None:
        self.field_name = field_name
        self.property_name = property_name
        self.generator = generator

    def generate_field(self, obj: object, object_dict: Dict[str, JsonType]) -> None:
        value = getattr(obj, self.field_name)
        if value is not None:
            object_dict[self.property_name] = self.generator.generate(value)


class TypedClassSerializer(Serializer[T]):
    property_generators: List[FieldSerializer]

    def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
        self.property_generators = [
            FieldSerializer(
                field_name,
                python_field_to_json_property(field_name, field_type),
                _get_serializer(field_type, context),
            )
            for field_name, field_type in get_class_properties(class_type)
        ]

    def generate(self, obj: T) -> Dict[str, JsonType]:
        object_dict: Dict[str, JsonType] = {}
        for property_generator in self.property_generators:
            property_generator.generate_field(obj, object_dict)

        return object_dict


class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]):
    def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None:
        super().__init__(class_type, context)


class DataclassSerializer(TypedClassSerializer[T]):
    def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
        super().__init__(class_type, context)


class UnionSerializer(Serializer):
    def generate(self, obj: Any) -> JsonType:
        return object_to_json(obj)


class LiteralSerializer(Serializer):
    generator: Serializer

    def __init__(self, values: Tuple[Any, ...], context: Optional[ModuleType]) -> None:
        literal_type_tuple = tuple(type(value) for value in values)
        literal_type_set = set(literal_type_tuple)
        if len(literal_type_set) != 1:
            value_names = ", ".join(repr(value) for value in values)
            raise TypeError(
                f"type `Literal[{value_names}]` expects consistent literal value types but got: {literal_type_tuple}"
            )

        literal_type = literal_type_set.pop()
        self.generator = _get_serializer(literal_type, context)

    def generate(self, obj: Any) -> JsonType:
        return self.generator.generate(obj)


class UntypedNamedTupleSerializer(Serializer):
    fields: Dict[str, str]

    def __init__(self, class_type: Type[NamedTuple]) -> None:
        # named tuples are also instances of tuple
        self.fields = {}
        field_names: Tuple[str, ...] = class_type._fields
        for field_name in field_names:
            self.fields[field_name] = python_field_to_json_property(field_name)

    def generate(self, obj: NamedTuple) -> JsonType:
        object_dict = {}
        for field_name, property_name in self.fields.items():
            value = getattr(obj, field_name)
            object_dict[property_name] = object_to_json(value)

        return object_dict


class UntypedClassSerializer(Serializer):
    def generate(self, obj: object) -> JsonType:
        # iterate over object attributes to get a standard representation
        object_dict = {}
        for name in dir(obj):
            if is_reserved_property(name):
                continue

            value = getattr(obj, name)
            if value is None:
                continue

            # filter instance methods
            if inspect.ismethod(value):
                continue

            object_dict[python_field_to_json_property(name)] = object_to_json(value)

        return object_dict


def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer:
    """
    Creates a serializer engine to produce an object that can be directly converted into a JSON string.

    When serializing a Python object into a JSON object, the following transformations are applied:

    * Fundamental types (`bool`, `int`, `float` or `str`) are returned as-is.
    * Date and time types (`datetime`, `date` or `time`) produce an ISO 8601 format string with time zone
      (ending with `Z` for UTC).
    * Byte arrays (`bytes`) are written as a string with Base64 encoding.
    * UUIDs (`uuid.UUID`) are written as a UUID string as per RFC 4122.
    * Enumerations yield their enumeration value.
    * Containers (e.g. `list`, `dict`, `set`, `tuple`) are processed recursively.
    * Complex objects with properties (including data class types) generate dictionaries of key-value pairs.

    :raises TypeError: A serializer engine cannot be constructed for the input type.
    """

    if context is None:
        if isinstance(typ, type):
            context = sys.modules[typ.__module__]

    return _get_serializer(typ, context)


def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
    if isinstance(typ, (str, typing.ForwardRef)):
        if context is None:
            raise TypeError(f"missing context for evaluating type: {typ}")

        typ = evaluate_type(typ, context)

    if isinstance(typ, type):
        return _fetch_serializer(typ)
    else:
        # special forms are not always hashable
        return _create_serializer(typ, context)


@functools.lru_cache(maxsize=None)
def _fetch_serializer(typ: type) -> Serializer:
    context = sys.modules[typ.__module__]
    return _create_serializer(typ, context)


def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
    # check for well-known types
    if typ is type(None):
        return NoneSerializer()
    elif typ is bool:
        return BoolSerializer()
    elif typ is int:
        return IntSerializer()
    elif typ is float:
        return FloatSerializer()
    elif typ is str:
        return StringSerializer()
    elif typ is bytes:
        return BytesSerializer()
    elif typ is datetime.datetime:
        return DateTimeSerializer()
    elif typ is datetime.date:
        return DateSerializer()
    elif typ is datetime.time:
        return TimeSerializer()
    elif typ is uuid.UUID:
        return UUIDSerializer()
    elif typ is ipaddress.IPv4Address:
        return IPv4Serializer()
    elif typ is ipaddress.IPv6Address:
        return IPv6Serializer()

    # dynamically-typed collection types
    if typ is list:
        return UntypedListSerializer()
    elif typ is dict:
        return UntypedDictSerializer()
    elif typ is set:
        return UntypedSetSerializer()
    elif typ is tuple:
        return UntypedTupleSerializer()

    # generic types (e.g. list, dict, set, etc.)
    origin_type = typing.get_origin(typ)
    if origin_type is list:
        (list_item_type,) = typing.get_args(typ)  # unpack single tuple element
        return TypedListSerializer(list_item_type, context)
    elif origin_type is dict:
        key_type, value_type = typing.get_args(typ)
        if key_type is str:
            return TypedStringDictSerializer(value_type, context)
        elif issubclass(key_type, enum.Enum):
            return TypedEnumDictSerializer(key_type, value_type, context)
    elif origin_type is set:
        (set_member_type,) = typing.get_args(typ)  # unpack single tuple element
        return TypedSetSerializer(set_member_type, context)
    elif origin_type is tuple:
        return TypedTupleSerializer(typing.get_args(typ), context)
    elif origin_type is Union:
        return UnionSerializer()
    elif origin_type is Literal:
        return LiteralSerializer(typing.get_args(typ), context)

    if is_type_annotated(typ):
        return create_serializer(unwrap_annotated_type(typ))

    # check if object has custom serialization method
    convert_func = getattr(typ, "to_json", None)
    if callable(convert_func):
        return CustomSerializer(convert_func)

    if is_type_enum(typ):
        return EnumSerializer()
    if is_dataclass_type(typ):
        return DataclassSerializer(typ, context)
    if is_named_tuple_type(typ):
        if getattr(typ, "__annotations__", None):
            return TypedNamedTupleSerializer(typ, context)
        else:
            return UntypedNamedTupleSerializer(typ)

    # fail early if caller passes an object with an exotic type
    if not isinstance(typ, type) or typ is FunctionType or typ is MethodType or typ is type or typ is ModuleType:
        raise TypeError(f"object of type {typ} cannot be represented in JSON")

    if get_resolved_hints(typ):
        return TypedClassSerializer(typ, context)
    else:
        return UntypedClassSerializer()


def object_to_json(obj: Any) -> JsonType:
    """
    Converts a Python object to a representation that can be exported to JSON.

    * Fundamental types (e.g. numeric types) are written as-is.
    * Date and time types are serialized in the ISO 8601 format with time zone.
    * A byte array is written as a string with Base64 encoding.
    * UUIDs are written as a UUID string.
    * Enumerations are written as their value.
    * Containers (e.g. `list`, `dict`, `set`, `tuple`) are exported recursively.
    * Objects with properties (including data class types) are converted to dictionaries of key-value pairs.
    """

    typ: type = type(obj)
    generator = create_serializer(typ)
    return generator.generate(obj)
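A short sketch of the transformations documented for `create_serializer` and `object_to_json` above; the `Color` enum and the sample values are illustrative, not part of the change.

```python
# Illustrative only: exercises the serializer engine added in this file.
import datetime
import enum
import uuid

from llama_stack.strong_typing.serializer import create_serializer, object_to_json


class Color(enum.Enum):
    RED = "red"
    BLUE = "blue"


print(object_to_json(Color.RED))         # "red" -- enumerations yield their value
print(object_to_json(b"\x00\x01"))       # "AAE=" -- bytes become a Base64 string
print(object_to_json(uuid.UUID(int=0)))  # "00000000-0000-0000-0000-000000000000"

ts = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
print(object_to_json(ts))                # "2024-01-01T00:00:00Z" -- UTC rendered with "Z"

serializer = create_serializer(datetime.datetime)
# A naive datetime raises JsonValueError because it lacks an explicit time zone:
# serializer.generate(datetime.datetime(2024, 1, 1))
```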
27
llama_stack/strong_typing/slots.py
Normal file
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Tuple, Type, TypeVar

T = TypeVar("T")


class SlotsMeta(type):
    def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T:
        # caller may have already provided slots, in which case just retain them and keep going
        slots: Tuple[str, ...] = ns.get("__slots__", ())

        # add fields with type annotations to slots
        annotations: Dict[str, Any] = ns.get("__annotations__", {})
        members = tuple(member for member in annotations.keys() if member not in slots)

        # assign slots
        ns["__slots__"] = slots + tuple(members)
        return super().__new__(cls, name, bases, ns)  # type: ignore


class Slots(metaclass=SlotsMeta):
    pass
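A minimal sketch of what `SlotsMeta` does at class-creation time; the `Point` class is hypothetical.

```python
# Illustrative only: annotated fields are collected into __slots__ by the metaclass.
from llama_stack.strong_typing.slots import Slots


class Point(Slots):
    x: int
    y: int

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y


print(Point.__slots__)  # ('x', 'y') -- gathered from the type annotations
p = Point(1, 2)
# p.z = 3  # would raise AttributeError, since instances have no __dict__
```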
89
llama_stack/strong_typing/topological.py
Normal file
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Type-safe data interchange for Python data classes.

:see: https://github.com/hunyadi/strong_typing
"""

from typing import Callable, Dict, Iterable, List, Optional, Set, TypeVar

from .inspection import TypeCollector

T = TypeVar("T")


def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
    """
    Performs a topological sort of a graph.

    Nodes with no outgoing edges are first. Nodes with no incoming edges are last.
    The topological ordering is not unique.

    :param graph: A dictionary of mappings from nodes to adjacent nodes. Keys and set members must be hashable.
    :returns: The list of nodes in topological order.
    """

    # empty list that will contain the sorted nodes (in reverse order)
    ordered: List[T] = []

    seen: Dict[T, bool] = {}

    def _visit(n: T) -> None:
        status = seen.get(n)
        if status is not None:
            if status:  # node has a permanent mark
                return
            else:  # node has a temporary mark
                raise RuntimeError(f"cycle detected in graph for node {n}")

        seen[n] = False  # apply temporary mark
        for m in graph[n]:  # visit all adjacent nodes
            if m != n:  # ignore self-referencing nodes
                _visit(m)

        seen[n] = True  # apply permanent mark
        ordered.append(n)

    for n in graph.keys():
        _visit(n)

    return ordered


def type_topological_sort(
    types: Iterable[type],
    dependency_fn: Optional[Callable[[type], Iterable[type]]] = None,
) -> List[type]:
    """
    Performs a topological sort of a list of types.

    Types that don't depend on other types (i.e. fundamental types) are first. Types on which no other types depend
    are last. The topological ordering is not unique.

    :param types: A list of types (simple or composite).
    :param dependency_fn: Returns a list of additional dependencies for a class (e.g. classes referenced by a foreign key).
    :returns: The list of types in topological order.
    """

    if not all(isinstance(typ, type) for typ in types):
        raise TypeError("expected a list of types")

    collector = TypeCollector()
    collector.traverse_all(types)
    graph = collector.graph

    if dependency_fn:
        new_types: Set[type] = set()
        for source_type, references in graph.items():
            dependent_types = dependency_fn(source_type)
            references.update(dependent_types)
            new_types.update(dependent_types)
        for new_type in new_types:
            graph[new_type] = set()

    return topological_sort(graph)
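A small sketch of `topological_sort` on a hand-built dependency graph; the node names are illustrative.

```python
# Illustrative only: each key maps to the set of nodes it depends on (its outgoing edges).
from llama_stack.strong_typing.topological import topological_sort

graph = {
    "app": {"lib", "util"},
    "lib": {"util"},
    "util": set(),
}
print(topological_sort(graph))  # ['util', 'lib', 'app'] -- dependencies come first
```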