# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import copy import dataclasses import datetime import decimal import enum import ipaddress import math import re import sys import types import typing import uuid from dataclasses import dataclass from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union from .auxiliary import ( Alias, Annotated, MaxLength, Precision, float32, float64, int16, int32, int64, ) from .core import JsonType, Schema from .docstring import Docstring, DocstringParam from .inspection import TypeLike from .serialization import json_to_object, object_to_json T = TypeVar("T") @dataclass class JsonSchemaNode: title: Optional[str] description: Optional[str] @dataclass class JsonSchemaType(JsonSchemaNode): type: str format: Optional[str] @dataclass class JsonSchemaBoolean(JsonSchemaType): type: Literal["boolean"] const: Optional[bool] default: Optional[bool] examples: Optional[List[bool]] @dataclass class JsonSchemaInteger(JsonSchemaType): type: Literal["integer"] const: Optional[int] default: Optional[int] examples: Optional[List[int]] enum: Optional[List[int]] minimum: Optional[int] maximum: Optional[int] @dataclass class JsonSchemaNumber(JsonSchemaType): type: Literal["number"] const: Optional[float] default: Optional[float] examples: Optional[List[float]] minimum: Optional[float] maximum: Optional[float] exclusiveMinimum: Optional[float] exclusiveMaximum: Optional[float] multipleOf: Optional[float] @dataclass class JsonSchemaString(JsonSchemaType): type: Literal["string"] const: Optional[str] default: Optional[str] examples: Optional[List[str]] enum: Optional[List[str]] minLength: Optional[int] maxLength: Optional[int] @dataclass class JsonSchemaArray(JsonSchemaType): type: Literal["array"] items: "JsonSchemaAny" @dataclass class JsonSchemaObject(JsonSchemaType): type: Literal["object"] properties: Optional[Dict[str, "JsonSchemaAny"]] additionalProperties: Optional[bool] required: Optional[List[str]] @dataclass class JsonSchemaRef(JsonSchemaNode): ref: Annotated[str, Alias("$ref")] @dataclass class JsonSchemaAllOf(JsonSchemaNode): allOf: List["JsonSchemaAny"] @dataclass class JsonSchemaAnyOf(JsonSchemaNode): anyOf: List["JsonSchemaAny"] @dataclass class Discriminator: propertyName: str mapping: Dict[str, str] @dataclass class JsonSchemaOneOf(JsonSchemaNode): oneOf: List["JsonSchemaAny"] discriminator: Optional[Discriminator] JsonSchemaAny = Union[ JsonSchemaRef, JsonSchemaBoolean, JsonSchemaInteger, JsonSchemaNumber, JsonSchemaString, JsonSchemaArray, JsonSchemaObject, JsonSchemaOneOf, ] @dataclass class JsonSchemaTopLevelObject(JsonSchemaObject): schema: Annotated[str, Alias("$schema")] definitions: Optional[Dict[str, JsonSchemaAny]] def integer_range_to_type(min_value: float, max_value: float) -> type: if min_value >= -(2**15) and max_value < 2**15: return int16 elif min_value >= -(2**31) and max_value < 2**31: return int32 else: return int64 def enum_safe_name(name: str) -> str: name = re.sub(r"\W", "_", name) is_dunder = name.startswith("__") is_sunder = name.startswith("_") and name.endswith("_") if is_dunder or is_sunder: # provide an alternative for dunder and sunder names name = f"v{name}" return name def enum_values_to_type( module: types.ModuleType, name: str, values: Dict[str, Any], title: Optional[str] = None, description: Optional[str] = None, ) -> Type[enum.Enum]: enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore # assign the newly created type to the same module where the defining class is enum_class.__module__ = module.__name__ enum_class.__doc__ = str(Docstring(short_description=title, long_description=description)) setattr(module, name, enum_class) return enum.unique(enum_class) def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str) -> TypeLike: """ Creates a Python type from a JSON schema. :param schema: The JSON schema that the types would correspond to. :param module: The module in which to create the new types. :param class_name: The name assigned to the top-level class. """ top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)) if top_node.definitions is not None: for type_name, type_node in top_node.definitions.items(): type_def = node_to_typedef(module, type_name, type_node) if type_def.default is not dataclasses.MISSING: raise TypeError("disallowed: `default` for top-level type definitions") setattr(type_def.type, "__module__", module.__name__) setattr(module, type_name, type_def.type) return node_to_typedef(module, class_name, top_node).type @dataclass class TypeDef: type: TypeLike default: Any = dataclasses.MISSING def json_to_value(target_type: TypeLike, data: JsonType) -> Any: if data is not None: return json_to_object(target_type, data) else: return dataclasses.MISSING def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode) -> TypeDef: if isinstance(node, JsonSchemaRef): match_obj = re.match(r"^#/definitions/(\w+)$", node.ref) if not match_obj: raise ValueError(f"invalid reference: {node.ref}") type_name = match_obj.group(1) return TypeDef(getattr(module, type_name), dataclasses.MISSING) elif isinstance(node, JsonSchemaBoolean): if node.const is not None: return TypeDef(Literal[node.const], dataclasses.MISSING) default = json_to_value(bool, node.default) return TypeDef(bool, default) elif isinstance(node, JsonSchemaInteger): if node.const is not None: return TypeDef(Literal[node.const], dataclasses.MISSING) integer_type: TypeLike if node.format == "int16": integer_type = int16 elif node.format == "int32": integer_type = int32 elif node.format == "int64": integer_type = int64 else: if node.enum is not None: integer_type = integer_range_to_type(min(node.enum), max(node.enum)) elif node.minimum is not None and node.maximum is not None: integer_type = integer_range_to_type(node.minimum, node.maximum) else: integer_type = int default = json_to_value(integer_type, node.default) return TypeDef(integer_type, default) elif isinstance(node, JsonSchemaNumber): if node.const is not None: return TypeDef(Literal[node.const], dataclasses.MISSING) number_type: TypeLike if node.format == "float32": number_type = float32 elif node.format == "float64": number_type = float64 else: if ( node.exclusiveMinimum is not None and node.exclusiveMaximum is not None and node.exclusiveMinimum == -node.exclusiveMaximum ): integer_digits = round(math.log10(node.exclusiveMaximum)) else: integer_digits = None if node.multipleOf is not None: decimal_digits = -round(math.log10(node.multipleOf)) else: decimal_digits = None if integer_digits is not None and decimal_digits is not None: number_type = Annotated[ decimal.Decimal, Precision(integer_digits + decimal_digits, decimal_digits), ] else: number_type = float default = json_to_value(number_type, node.default) return TypeDef(number_type, default) elif isinstance(node, JsonSchemaString): if node.const is not None: return TypeDef(Literal[node.const], dataclasses.MISSING) string_type: TypeLike if node.format == "date-time": string_type = datetime.datetime elif node.format == "uuid": string_type = uuid.UUID elif node.format == "ipv4": string_type = ipaddress.IPv4Address elif node.format == "ipv6": string_type = ipaddress.IPv6Address elif node.enum is not None: string_type = enum_values_to_type( module, context, {enum_safe_name(e): e for e in node.enum}, title=node.title, description=node.description, ) elif node.maxLength is not None: string_type = Annotated[str, MaxLength(node.maxLength)] else: string_type = str default = json_to_value(string_type, node.default) return TypeDef(string_type, default) elif isinstance(node, JsonSchemaArray): type_def = node_to_typedef(module, context, node.items) if type_def.default is not dataclasses.MISSING: raise TypeError("disallowed: `default` for array element type") list_type = List[(type_def.type,)] # type: ignore return TypeDef(list_type, dataclasses.MISSING) elif isinstance(node, JsonSchemaObject): if node.properties is None: return TypeDef(JsonType, dataclasses.MISSING) if node.additionalProperties is None or node.additionalProperties is not False: raise TypeError("expected: `additionalProperties` equals `false`") required = node.required if node.required is not None else [] class_name = context fields: List[Tuple[str, Any, dataclasses.Field]] = [] params: Dict[str, DocstringParam] = {} for prop_name, prop_node in node.properties.items(): type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node) if prop_name in required: prop_type = type_def.type else: prop_type = Union[(None, type_def.type)] fields.append((prop_name, prop_type, dataclasses.field(default=type_def.default))) prop_desc = prop_node.title or prop_node.description if prop_desc is not None: params[prop_name] = DocstringParam(prop_name, prop_desc) fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING) if sys.version_info >= (3, 12): class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__) else: class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__}) class_type.__doc__ = str( Docstring( short_description=node.title, long_description=node.description, params=params, ) ) setattr(module, class_name, class_type) return TypeDef(class_type, dataclasses.MISSING) elif isinstance(node, JsonSchemaOneOf): union_defs = tuple(node_to_typedef(module, context, n) for n in node.oneOf) if any(d.default is not dataclasses.MISSING for d in union_defs): raise TypeError("disallowed: `default` for union member type") union_types = tuple(d.type for d in union_defs) return TypeDef(Union[union_types], dataclasses.MISSING) raise NotImplementedError() @dataclass class SchemaFlatteningOptions: qualified_names: bool = False recursive: bool = False def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema: top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)) flattener = SchemaFlattener(options) obj = flattener.flatten(top_node) return typing.cast(Schema, object_to_json(obj)) class SchemaFlattener: options: SchemaFlatteningOptions def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None: self.options = options or SchemaFlatteningOptions() def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject: if source_node.type != "object": return source_node source_props = source_node.properties or {} target_props: Dict[str, JsonSchemaAny] = {} source_reqs = source_node.required or [] target_reqs: List[str] = [] for name, prop in source_props.items(): if not isinstance(prop, JsonSchemaObject): target_props[name] = prop if name in source_reqs: target_reqs.append(name) continue if self.options.recursive: obj = self.flatten(prop) else: obj = prop if obj.properties is not None: if self.options.qualified_names: target_props.update((f"{name}.{n}", p) for n, p in obj.properties.items()) else: target_props.update(obj.properties.items()) if obj.required is not None: if self.options.qualified_names: target_reqs.extend(f"{name}.{n}" for n in obj.required) else: target_reqs.extend(obj.required) target_node = copy.copy(source_node) target_node.properties = target_props or None target_node.additionalProperties = False target_node.required = target_reqs or None return target_node