forked from phoenix-oss/llama-stack-mirror
		
	See https://swagger.io/docs/specification/v3_0/data-models/inheritance-and-polymorphism/#discriminator When specifying discriminators, mapping must be specified unless the value of the discriminator is the subtype itself (which in our case is not.) The changes in the YAML are self-explanatory.
		
			
				
	
	
		
			460 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			460 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import copy
 | |
| import dataclasses
 | |
| import datetime
 | |
| import decimal
 | |
| import enum
 | |
| import ipaddress
 | |
| import math
 | |
| import re
 | |
| import sys
 | |
| import types
 | |
| import typing
 | |
| import uuid
 | |
| from dataclasses import dataclass
 | |
| from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
 | |
| 
 | |
| from .auxiliary import (
 | |
|     Alias,
 | |
|     Annotated,
 | |
|     float32,
 | |
|     float64,
 | |
|     int16,
 | |
|     int32,
 | |
|     int64,
 | |
|     MaxLength,
 | |
|     Precision,
 | |
| )
 | |
| from .core import JsonType, Schema
 | |
| from .docstring import Docstring, DocstringParam
 | |
| from .inspection import TypeLike
 | |
| from .serialization import json_to_object, object_to_json
 | |
| 
 | |
| T = TypeVar("T")
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaNode:
 | |
|     title: Optional[str]
 | |
|     description: Optional[str]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaType(JsonSchemaNode):
 | |
|     type: str
 | |
|     format: Optional[str]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaBoolean(JsonSchemaType):
 | |
|     type: Literal["boolean"]
 | |
|     const: Optional[bool]
 | |
|     default: Optional[bool]
 | |
|     examples: Optional[List[bool]]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaInteger(JsonSchemaType):
 | |
|     type: Literal["integer"]
 | |
|     const: Optional[int]
 | |
|     default: Optional[int]
 | |
|     examples: Optional[List[int]]
 | |
|     enum: Optional[List[int]]
 | |
|     minimum: Optional[int]
 | |
|     maximum: Optional[int]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaNumber(JsonSchemaType):
 | |
|     type: Literal["number"]
 | |
|     const: Optional[float]
 | |
|     default: Optional[float]
 | |
|     examples: Optional[List[float]]
 | |
|     minimum: Optional[float]
 | |
|     maximum: Optional[float]
 | |
|     exclusiveMinimum: Optional[float]
 | |
|     exclusiveMaximum: Optional[float]
 | |
|     multipleOf: Optional[float]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaString(JsonSchemaType):
 | |
|     type: Literal["string"]
 | |
|     const: Optional[str]
 | |
|     default: Optional[str]
 | |
|     examples: Optional[List[str]]
 | |
|     enum: Optional[List[str]]
 | |
|     minLength: Optional[int]
 | |
|     maxLength: Optional[int]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaArray(JsonSchemaType):
 | |
|     type: Literal["array"]
 | |
|     items: "JsonSchemaAny"
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaObject(JsonSchemaType):
 | |
|     type: Literal["object"]
 | |
|     properties: Optional[Dict[str, "JsonSchemaAny"]]
 | |
|     additionalProperties: Optional[bool]
 | |
|     required: Optional[List[str]]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaRef(JsonSchemaNode):
 | |
|     ref: Annotated[str, Alias("$ref")]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaAllOf(JsonSchemaNode):
 | |
|     allOf: List["JsonSchemaAny"]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaAnyOf(JsonSchemaNode):
 | |
|     anyOf: List["JsonSchemaAny"]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class Discriminator:
 | |
|     propertyName: str
 | |
|     mapping: Dict[str, str]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaOneOf(JsonSchemaNode):
 | |
|     oneOf: List["JsonSchemaAny"]
 | |
|     discriminator: Optional[Discriminator]
 | |
| 
 | |
| 
 | |
| JsonSchemaAny = Union[
 | |
|     JsonSchemaRef,
 | |
|     JsonSchemaBoolean,
 | |
|     JsonSchemaInteger,
 | |
|     JsonSchemaNumber,
 | |
|     JsonSchemaString,
 | |
|     JsonSchemaArray,
 | |
|     JsonSchemaObject,
 | |
|     JsonSchemaOneOf,
 | |
| ]
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class JsonSchemaTopLevelObject(JsonSchemaObject):
 | |
|     schema: Annotated[str, Alias("$schema")]
 | |
|     definitions: Optional[Dict[str, JsonSchemaAny]]
 | |
| 
 | |
| 
 | |
| def integer_range_to_type(min_value: float, max_value: float) -> type:
 | |
|     if min_value >= -(2**15) and max_value < 2**15:
 | |
|         return int16
 | |
|     elif min_value >= -(2**31) and max_value < 2**31:
 | |
|         return int32
 | |
|     else:
 | |
|         return int64
 | |
| 
 | |
| 
 | |
| def enum_safe_name(name: str) -> str:
 | |
|     name = re.sub(r"\W", "_", name)
 | |
|     is_dunder = name.startswith("__")
 | |
|     is_sunder = name.startswith("_") and name.endswith("_")
 | |
|     if is_dunder or is_sunder:  # provide an alternative for dunder and sunder names
 | |
|         name = f"v{name}"
 | |
|     return name
 | |
| 
 | |
| 
 | |
| def enum_values_to_type(
 | |
|     module: types.ModuleType,
 | |
|     name: str,
 | |
|     values: Dict[str, Any],
 | |
|     title: Optional[str] = None,
 | |
|     description: Optional[str] = None,
 | |
| ) -> Type[enum.Enum]:
 | |
|     enum_class: Type[enum.Enum] = enum.Enum(name, values)  # type: ignore
 | |
| 
 | |
|     # assign the newly created type to the same module where the defining class is
 | |
|     enum_class.__module__ = module.__name__
 | |
|     enum_class.__doc__ = str(
 | |
|         Docstring(short_description=title, long_description=description)
 | |
|     )
 | |
|     setattr(module, name, enum_class)
 | |
| 
 | |
|     return enum.unique(enum_class)
 | |
| 
 | |
| 
 | |
| def schema_to_type(
 | |
|     schema: Schema, *, module: types.ModuleType, class_name: str
 | |
| ) -> TypeLike:
 | |
|     """
 | |
|     Creates a Python type from a JSON schema.
 | |
| 
 | |
|     :param schema: The JSON schema that the types would correspond to.
 | |
|     :param module: The module in which to create the new types.
 | |
|     :param class_name: The name assigned to the top-level class.
 | |
|     """
 | |
| 
 | |
|     top_node = typing.cast(
 | |
|         JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)
 | |
|     )
 | |
|     if top_node.definitions is not None:
 | |
|         for type_name, type_node in top_node.definitions.items():
 | |
|             type_def = node_to_typedef(module, type_name, type_node)
 | |
|             if type_def.default is not dataclasses.MISSING:
 | |
|                 raise TypeError("disallowed: `default` for top-level type definitions")
 | |
| 
 | |
|             setattr(type_def.type, "__module__", module.__name__)
 | |
|             setattr(module, type_name, type_def.type)
 | |
| 
 | |
|     return node_to_typedef(module, class_name, top_node).type
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class TypeDef:
 | |
|     type: TypeLike
 | |
|     default: Any = dataclasses.MISSING
 | |
| 
 | |
| 
 | |
| def json_to_value(target_type: TypeLike, data: JsonType) -> Any:
 | |
|     if data is not None:
 | |
|         return json_to_object(target_type, data)
 | |
|     else:
 | |
|         return dataclasses.MISSING
 | |
| 
 | |
| 
 | |
| def node_to_typedef(
 | |
|     module: types.ModuleType, context: str, node: JsonSchemaNode
 | |
| ) -> TypeDef:
 | |
|     if isinstance(node, JsonSchemaRef):
 | |
|         match_obj = re.match(r"^#/definitions/(\w+)$", node.ref)
 | |
|         if not match_obj:
 | |
|             raise ValueError(f"invalid reference: {node.ref}")
 | |
| 
 | |
|         type_name = match_obj.group(1)
 | |
|         return TypeDef(getattr(module, type_name), dataclasses.MISSING)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaBoolean):
 | |
|         if node.const is not None:
 | |
|             return TypeDef(Literal[node.const], dataclasses.MISSING)
 | |
| 
 | |
|         default = json_to_value(bool, node.default)
 | |
|         return TypeDef(bool, default)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaInteger):
 | |
|         if node.const is not None:
 | |
|             return TypeDef(Literal[node.const], dataclasses.MISSING)
 | |
| 
 | |
|         integer_type: TypeLike
 | |
|         if node.format == "int16":
 | |
|             integer_type = int16
 | |
|         elif node.format == "int32":
 | |
|             integer_type = int32
 | |
|         elif node.format == "int64":
 | |
|             integer_type = int64
 | |
|         else:
 | |
|             if node.enum is not None:
 | |
|                 integer_type = integer_range_to_type(min(node.enum), max(node.enum))
 | |
|             elif node.minimum is not None and node.maximum is not None:
 | |
|                 integer_type = integer_range_to_type(node.minimum, node.maximum)
 | |
|             else:
 | |
|                 integer_type = int
 | |
| 
 | |
|         default = json_to_value(integer_type, node.default)
 | |
|         return TypeDef(integer_type, default)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaNumber):
 | |
|         if node.const is not None:
 | |
|             return TypeDef(Literal[node.const], dataclasses.MISSING)
 | |
| 
 | |
|         number_type: TypeLike
 | |
|         if node.format == "float32":
 | |
|             number_type = float32
 | |
|         elif node.format == "float64":
 | |
|             number_type = float64
 | |
|         else:
 | |
|             if (
 | |
|                 node.exclusiveMinimum is not None
 | |
|                 and node.exclusiveMaximum is not None
 | |
|                 and node.exclusiveMinimum == -node.exclusiveMaximum
 | |
|             ):
 | |
|                 integer_digits = round(math.log10(node.exclusiveMaximum))
 | |
|             else:
 | |
|                 integer_digits = None
 | |
| 
 | |
|             if node.multipleOf is not None:
 | |
|                 decimal_digits = -round(math.log10(node.multipleOf))
 | |
|             else:
 | |
|                 decimal_digits = None
 | |
| 
 | |
|             if integer_digits is not None and decimal_digits is not None:
 | |
|                 number_type = Annotated[
 | |
|                     decimal.Decimal,
 | |
|                     Precision(integer_digits + decimal_digits, decimal_digits),
 | |
|                 ]
 | |
|             else:
 | |
|                 number_type = float
 | |
| 
 | |
|         default = json_to_value(number_type, node.default)
 | |
|         return TypeDef(number_type, default)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaString):
 | |
|         if node.const is not None:
 | |
|             return TypeDef(Literal[node.const], dataclasses.MISSING)
 | |
| 
 | |
|         string_type: TypeLike
 | |
|         if node.format == "date-time":
 | |
|             string_type = datetime.datetime
 | |
|         elif node.format == "uuid":
 | |
|             string_type = uuid.UUID
 | |
|         elif node.format == "ipv4":
 | |
|             string_type = ipaddress.IPv4Address
 | |
|         elif node.format == "ipv6":
 | |
|             string_type = ipaddress.IPv6Address
 | |
| 
 | |
|         elif node.enum is not None:
 | |
|             string_type = enum_values_to_type(
 | |
|                 module,
 | |
|                 context,
 | |
|                 {enum_safe_name(e): e for e in node.enum},
 | |
|                 title=node.title,
 | |
|                 description=node.description,
 | |
|             )
 | |
| 
 | |
|         elif node.maxLength is not None:
 | |
|             string_type = Annotated[str, MaxLength(node.maxLength)]
 | |
|         else:
 | |
|             string_type = str
 | |
| 
 | |
|         default = json_to_value(string_type, node.default)
 | |
|         return TypeDef(string_type, default)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaArray):
 | |
|         type_def = node_to_typedef(module, context, node.items)
 | |
|         if type_def.default is not dataclasses.MISSING:
 | |
|             raise TypeError("disallowed: `default` for array element type")
 | |
|         list_type = List[(type_def.type,)]  # type: ignore
 | |
|         return TypeDef(list_type, dataclasses.MISSING)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaObject):
 | |
|         if node.properties is None:
 | |
|             return TypeDef(JsonType, dataclasses.MISSING)
 | |
| 
 | |
|         if node.additionalProperties is None or node.additionalProperties is not False:
 | |
|             raise TypeError("expected: `additionalProperties` equals `false`")
 | |
| 
 | |
|         required = node.required if node.required is not None else []
 | |
| 
 | |
|         class_name = context
 | |
| 
 | |
|         fields: List[Tuple[str, Any, dataclasses.Field]] = []
 | |
|         params: Dict[str, DocstringParam] = {}
 | |
|         for prop_name, prop_node in node.properties.items():
 | |
|             type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node)
 | |
|             if prop_name in required:
 | |
|                 prop_type = type_def.type
 | |
|             else:
 | |
|                 prop_type = Union[(None, type_def.type)]
 | |
|             fields.append(
 | |
|                 (prop_name, prop_type, dataclasses.field(default=type_def.default))
 | |
|             )
 | |
|             prop_desc = prop_node.title or prop_node.description
 | |
|             if prop_desc is not None:
 | |
|                 params[prop_name] = DocstringParam(prop_name, prop_desc)
 | |
| 
 | |
|         fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING)
 | |
|         if sys.version_info >= (3, 12):
 | |
|             class_type = dataclasses.make_dataclass(
 | |
|                 class_name, fields, module=module.__name__
 | |
|             )
 | |
|         else:
 | |
|             class_type = dataclasses.make_dataclass(
 | |
|                 class_name, fields, namespace={"__module__": module.__name__}
 | |
|             )
 | |
|         class_type.__doc__ = str(
 | |
|             Docstring(
 | |
|                 short_description=node.title,
 | |
|                 long_description=node.description,
 | |
|                 params=params,
 | |
|             )
 | |
|         )
 | |
|         setattr(module, class_name, class_type)
 | |
|         return TypeDef(class_type, dataclasses.MISSING)
 | |
| 
 | |
|     elif isinstance(node, JsonSchemaOneOf):
 | |
|         union_defs = tuple(node_to_typedef(module, context, n) for n in node.oneOf)
 | |
|         if any(d.default is not dataclasses.MISSING for d in union_defs):
 | |
|             raise TypeError("disallowed: `default` for union member type")
 | |
|         union_types = tuple(d.type for d in union_defs)
 | |
|         return TypeDef(Union[union_types], dataclasses.MISSING)
 | |
| 
 | |
|     raise NotImplementedError()
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class SchemaFlatteningOptions:
 | |
|     qualified_names: bool = False
 | |
|     recursive: bool = False
 | |
| 
 | |
| 
 | |
| def flatten_schema(
 | |
|     schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None
 | |
| ) -> Schema:
 | |
|     top_node = typing.cast(
 | |
|         JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)
 | |
|     )
 | |
|     flattener = SchemaFlattener(options)
 | |
|     obj = flattener.flatten(top_node)
 | |
|     return typing.cast(Schema, object_to_json(obj))
 | |
| 
 | |
| 
 | |
| class SchemaFlattener:
 | |
|     options: SchemaFlatteningOptions
 | |
| 
 | |
|     def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None:
 | |
|         self.options = options or SchemaFlatteningOptions()
 | |
| 
 | |
|     def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject:
 | |
|         if source_node.type != "object":
 | |
|             return source_node
 | |
| 
 | |
|         source_props = source_node.properties or {}
 | |
|         target_props: Dict[str, JsonSchemaAny] = {}
 | |
| 
 | |
|         source_reqs = source_node.required or []
 | |
|         target_reqs: List[str] = []
 | |
| 
 | |
|         for name, prop in source_props.items():
 | |
|             if not isinstance(prop, JsonSchemaObject):
 | |
|                 target_props[name] = prop
 | |
|                 if name in source_reqs:
 | |
|                     target_reqs.append(name)
 | |
|                 continue
 | |
| 
 | |
|             if self.options.recursive:
 | |
|                 obj = self.flatten(prop)
 | |
|             else:
 | |
|                 obj = prop
 | |
|             if obj.properties is not None:
 | |
|                 if self.options.qualified_names:
 | |
|                     target_props.update(
 | |
|                         (f"{name}.{n}", p) for n, p in obj.properties.items()
 | |
|                     )
 | |
|                 else:
 | |
|                     target_props.update(obj.properties.items())
 | |
|             if obj.required is not None:
 | |
|                 if self.options.qualified_names:
 | |
|                     target_reqs.extend(f"{name}.{n}" for n in obj.required)
 | |
|                 else:
 | |
|                     target_reqs.extend(obj.required)
 | |
| 
 | |
|         target_node = copy.copy(source_node)
 | |
|         target_node.properties = target_props or None
 | |
|         target_node.additionalProperties = False
 | |
|         target_node.required = target_reqs or None
 | |
|         return target_node
 |