move openapi from rfcs->docs

This commit is contained in:
Xi Yan 2024-09-18 16:09:17 -07:00
parent 21058be0c1
commit 2c1ad10710
15 changed files with 9532 additions and 1 deletions

View file

@ -1,9 +0,0 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/[<subdir>]/api/endpoints.py` using the `generate.py` utility.
Please install the following packages before running the script:
```
pip install python-openapi json-strong-typing fire PyYAML llama-models
```
Then simply run `sh run_openapi_generator.sh <OUTPUT_DIR>`

View file

@ -1,87 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described found in the
# LICENSE file in the root directory of this source tree.
from datetime import datetime
from pathlib import Path
import fire
import yaml
from llama_models import schema_utils
# We do some monkey-patching to ensure our definitions only use the minimal
# (json_schema_type, webmethod) definitions from the llama_models package. For
# generation though, we need the full definitions and implementations from the
# (json-strong-typing) package.
from strong_typing.schema import json_schema_type
from .pyopenapi.options import Options
from .pyopenapi.specification import Info, Server
from .pyopenapi.utility import Specification
schema_utils.json_schema_type = json_schema_type
from llama_stack.apis.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
STREAMING_ENDPOINTS = [
"/agentic_system/turn/create",
"/inference/chat_completion",
]
def patch_sse_stream_responses(spec: Specification):
for path, path_item in spec.document.paths.items():
if path in STREAMING_ENDPOINTS:
content = path_item.post.responses["200"].content.pop("application/json")
path_item.post.responses["200"].content["text/event-stream"] = content
def main(output_dir: str):
output_dir = Path(output_dir)
if not output_dir.exists():
raise ValueError(f"Directory {output_dir} does not exist")
now = str(datetime.now())
print(
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
)
print("")
spec = Specification(
LlamaStack,
Options(
server=Server(url="http://any-hosted-llama-stack.com"),
info=Info(
title="[DRAFT] Llama Stack Specification",
version="0.0.1",
description="""This is the specification of the llama stack that provides
a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models. The specification is still in draft and subject to change.
Generated at """
+ now,
),
),
)
patch_sse_stream_responses(spec)
with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True)
with open(output_dir / "llama-stack-spec.html", "w") as fp:
spec.write_html(fp, pretty_print=True)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1 +0,0 @@
This is forked from https://github.com/hunyadi/pyopenapi

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,718 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import ipaddress
import typing
from typing import Any, Dict, Set, Union
from strong_typing.core import JsonType
from strong_typing.docstring import Docstring, parse_type
from strong_typing.inspection import (
is_generic_list,
is_type_optional,
is_type_union,
unwrap_generic_list,
unwrap_optional_type,
unwrap_union_types,
)
from strong_typing.name import python_type_to_name
from strong_typing.schema import (
get_schema_identifier,
JsonSchemaGenerator,
register_schema,
Schema,
SchemaOptions,
)
from strong_typing.serialization import json_dump_string, object_to_json
from .operations import (
EndpointOperation,
get_endpoint_events,
get_endpoint_operations,
HTTPMethod,
)
from .options import *
from .specification import (
Components,
Document,
Example,
ExampleRef,
MediaType,
Operation,
Parameter,
ParameterLocation,
PathItem,
RequestBody,
Response,
ResponseRef,
SchemaOrRef,
SchemaRef,
Tag,
TagGroup,
)
register_schema(
ipaddress.IPv4Address,
schema={
"type": "string",
"format": "ipv4",
"title": "IPv4 address",
"description": "IPv4 address, according to dotted-quad ABNF syntax as defined in RFC 2673, section 3.2.",
},
examples=["192.0.2.0", "198.51.100.1", "203.0.113.255"],
)
register_schema(
ipaddress.IPv6Address,
schema={
"type": "string",
"format": "ipv6",
"title": "IPv6 address",
"description": "IPv6 address, as defined in RFC 2373, section 2.2.",
},
examples=[
"FEDC:BA98:7654:3210:FEDC:BA98:7654:3210",
"1080:0:0:0:8:800:200C:417A",
"1080::8:800:200C:417A",
"FF01::101",
"::1",
],
)
def http_status_to_string(status_code: HTTPStatusCode) -> str:
"Converts an HTTP status code to a string."
if isinstance(status_code, HTTPStatus):
return str(status_code.value)
elif isinstance(status_code, int):
return str(status_code)
elif isinstance(status_code, str):
return status_code
else:
raise TypeError("expected: HTTP status code")
class SchemaBuilder:
schema_generator: JsonSchemaGenerator
schemas: Dict[str, Schema]
def __init__(self, schema_generator: JsonSchemaGenerator) -> None:
self.schema_generator = schema_generator
self.schemas = {}
def classdef_to_schema(self, typ: type) -> Schema:
"""
Converts a type to a JSON schema.
For nested types found in the type hierarchy, adds the type to the schema registry in the OpenAPI specification section `components`.
"""
type_schema, type_definitions = self.schema_generator.classdef_to_schema(typ)
# append schema to list of known schemas, to be used in OpenAPI's Components Object section
for ref, schema in type_definitions.items():
self._add_ref(ref, schema)
return type_schema
def classdef_to_named_schema(self, name: str, typ: type) -> Schema:
schema = self.classdef_to_schema(typ)
self._add_ref(name, schema)
return schema
def classdef_to_ref(self, typ: type) -> SchemaOrRef:
"""
Converts a type to a JSON schema, and if possible, returns a schema reference.
For composite types (such as classes), adds the type to the schema registry in the OpenAPI specification section `components`.
"""
type_schema = self.classdef_to_schema(typ)
if typ is str or typ is int or typ is float:
# represent simple types as themselves
return type_schema
type_name = get_schema_identifier(typ)
if type_name is not None:
return self._build_ref(type_name, type_schema)
try:
type_name = python_type_to_name(typ)
return self._build_ref(type_name, type_schema)
except TypeError:
pass
return type_schema
def _build_ref(self, type_name: str, type_schema: Schema) -> SchemaRef:
self._add_ref(type_name, type_schema)
return SchemaRef(type_name)
def _add_ref(self, type_name: str, type_schema: Schema) -> None:
if type_name not in self.schemas:
self.schemas[type_name] = type_schema
class ContentBuilder:
schema_builder: SchemaBuilder
schema_transformer: Optional[Callable[[SchemaOrRef], SchemaOrRef]]
sample_transformer: Optional[Callable[[JsonType], JsonType]]
def __init__(
self,
schema_builder: SchemaBuilder,
schema_transformer: Optional[Callable[[SchemaOrRef], SchemaOrRef]] = None,
sample_transformer: Optional[Callable[[JsonType], JsonType]] = None,
) -> None:
self.schema_builder = schema_builder
self.schema_transformer = schema_transformer
self.sample_transformer = sample_transformer
def build_content(
self, payload_type: type, examples: Optional[List[Any]] = None
) -> Dict[str, MediaType]:
"Creates the content subtree for a request or response."
if is_generic_list(payload_type):
media_type = "application/jsonl"
item_type = unwrap_generic_list(payload_type)
else:
media_type = "application/json"
item_type = payload_type
return {media_type: self.build_media_type(item_type, examples)}
def build_media_type(
self, item_type: type, examples: Optional[List[Any]] = None
) -> MediaType:
schema = self.schema_builder.classdef_to_ref(item_type)
if self.schema_transformer:
schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = self.schema_transformer # type: ignore
schema = schema_transformer(schema)
if not examples:
return MediaType(schema=schema)
if len(examples) == 1:
return MediaType(schema=schema, example=self._build_example(examples[0]))
return MediaType(
schema=schema,
examples=self._build_examples(examples),
)
def _build_examples(
self, examples: List[Any]
) -> Dict[str, Union[Example, ExampleRef]]:
"Creates a set of several examples for a media type."
if self.sample_transformer:
sample_transformer: Callable[[JsonType], JsonType] = self.sample_transformer # type: ignore
else:
sample_transformer = lambda sample: sample
results: Dict[str, Union[Example, ExampleRef]] = {}
for example in examples:
value = sample_transformer(object_to_json(example))
hash_string = (
hashlib.md5(json_dump_string(value).encode("utf-8")).digest().hex()
)
name = f"ex-{hash_string}"
results[name] = Example(value=value)
return results
def _build_example(self, example: Any) -> Any:
"Creates a single example for a media type."
if self.sample_transformer:
sample_transformer: Callable[[JsonType], JsonType] = self.sample_transformer # type: ignore
else:
sample_transformer = lambda sample: sample
return sample_transformer(object_to_json(example))
@dataclass
class ResponseOptions:
"""
Configuration options for building a response for an operation.
:param type_descriptions: Maps each response type to a textual description (if available).
:param examples: A list of response examples.
:param status_catalog: Maps each response type to an HTTP status code.
:param default_status_code: HTTP status code assigned to responses that have no mapping.
"""
type_descriptions: Dict[type, str]
examples: Optional[List[Any]]
status_catalog: Dict[type, HTTPStatusCode]
default_status_code: HTTPStatusCode
@dataclass
class StatusResponse:
status_code: str
types: List[type] = dataclasses.field(default_factory=list)
examples: List[Any] = dataclasses.field(default_factory=list)
class ResponseBuilder:
content_builder: ContentBuilder
def __init__(self, content_builder: ContentBuilder) -> None:
self.content_builder = content_builder
def _get_status_responses(
self, options: ResponseOptions
) -> Dict[str, StatusResponse]:
status_responses: Dict[str, StatusResponse] = {}
for response_type in options.type_descriptions.keys():
status_code = http_status_to_string(
options.status_catalog.get(response_type, options.default_status_code)
)
# look up response for status code
if status_code not in status_responses:
status_responses[status_code] = StatusResponse(status_code)
status_response = status_responses[status_code]
# append response types that are assigned the given status code
status_response.types.append(response_type)
# append examples that have the matching response type
if options.examples:
status_response.examples.extend(
example
for example in options.examples
if isinstance(example, response_type)
)
return dict(sorted(status_responses.items()))
def build_response(
self, options: ResponseOptions
) -> Dict[str, Union[Response, ResponseRef]]:
"""
Groups responses that have the same status code.
"""
responses: Dict[str, Union[Response, ResponseRef]] = {}
status_responses = self._get_status_responses(options)
for status_code, status_response in status_responses.items():
response_types = tuple(status_response.types)
if len(response_types) > 1:
composite_response_type: type = Union[response_types] # type: ignore
else:
(response_type,) = response_types
composite_response_type = response_type
description = " **OR** ".join(
filter(
None,
(
options.type_descriptions[response_type]
for response_type in response_types
),
)
)
responses[status_code] = self._build_response(
response_type=composite_response_type,
description=description,
examples=status_response.examples or None,
)
return responses
def _build_response(
self,
response_type: type,
description: str,
examples: Optional[List[Any]] = None,
) -> Response:
"Creates a response subtree."
if response_type is not None:
return Response(
description=description,
content=self.content_builder.build_content(response_type, examples),
)
else:
return Response(description=description)
def schema_error_wrapper(schema: SchemaOrRef) -> Schema:
"Wraps an error output schema into a top-level error schema."
return {
"type": "object",
"properties": {
"error": schema, # type: ignore
},
"additionalProperties": False,
"required": [
"error",
],
}
def sample_error_wrapper(error: JsonType) -> JsonType:
"Wraps an error output sample into a top-level error sample."
return {"error": error}
class Generator:
endpoint: type
options: Options
schema_builder: SchemaBuilder
responses: Dict[str, Response]
def __init__(self, endpoint: type, options: Options) -> None:
self.endpoint = endpoint
self.options = options
schema_generator = JsonSchemaGenerator(
SchemaOptions(
definitions_path="#/components/schemas/",
use_examples=self.options.use_examples,
property_description_fun=options.property_description_fun,
)
)
self.schema_builder = SchemaBuilder(schema_generator)
self.responses = {}
def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
definition = f'<SchemaDefinition schemaRef="#/components/schemas/{ref}" />'
title = typing.cast(str, schema.get("title"))
description = typing.cast(str, schema.get("description"))
return Tag(
name=ref,
description="\n\n".join(
s for s in (title, description, definition) if s is not None
),
)
def _build_extra_tag_groups(
self, extra_types: Dict[str, List[type]]
) -> Dict[str, List[Tag]]:
"""
Creates a dictionary of tag group captions as keys, and tag lists as values.
:param extra_types: A dictionary of type categories and list of types in that category.
"""
extra_tags: Dict[str, List[Tag]] = {}
for category_name, category_items in extra_types.items():
tag_list: List[Tag] = []
for extra_type in category_items:
name = python_type_to_name(extra_type)
schema = self.schema_builder.classdef_to_named_schema(name, extra_type)
tag_list.append(self._build_type_tag(name, schema))
if tag_list:
extra_tags[category_name] = tag_list
return extra_tags
def _build_operation(self, op: EndpointOperation) -> Operation:
doc_string = parse_type(op.func_ref)
doc_params = dict(
(param.name, param.description) for param in doc_string.params.values()
)
# parameters passed in URL component path
path_parameters = [
Parameter(
name=param_name,
in_=ParameterLocation.Path,
description=doc_params.get(param_name),
required=True,
schema=self.schema_builder.classdef_to_ref(param_type),
)
for param_name, param_type in op.path_params
]
# parameters passed in URL component query string
query_parameters = []
for param_name, param_type in op.query_params:
if is_type_optional(param_type):
inner_type: type = unwrap_optional_type(param_type)
required = False
else:
inner_type = param_type
required = True
query_parameter = Parameter(
name=param_name,
in_=ParameterLocation.Query,
description=doc_params.get(param_name),
required=required,
schema=self.schema_builder.classdef_to_ref(inner_type),
)
query_parameters.append(query_parameter)
# parameters passed anywhere
parameters = path_parameters + query_parameters
# data passed in payload
if op.request_params:
builder = ContentBuilder(self.schema_builder)
first = next(iter(op.request_params))
request_name, request_type = first
from dataclasses import make_dataclass
op_name = "".join(word.capitalize() for word in op.name.split("_"))
request_name = f"{op_name}Request"
request_type = make_dataclass(request_name, op.request_params)
requestBody = RequestBody(
content={
"application/json": builder.build_media_type(
request_type, op.request_examples
)
},
description=doc_params.get(request_name),
required=True,
)
else:
requestBody = None
# success response types
if doc_string.returns is None and is_type_union(op.response_type):
# split union of return types into a list of response types
success_type_docstring: Dict[type, Docstring] = {
typing.cast(type, item): parse_type(item)
for item in unwrap_union_types(op.response_type)
}
success_type_descriptions = {
item: doc_string.short_description
for item, doc_string in success_type_docstring.items()
if doc_string.short_description
}
else:
# use return type as a single response type
success_type_descriptions = {
op.response_type: (
doc_string.returns.description if doc_string.returns else "OK"
)
}
response_examples = op.response_examples or []
success_examples = [
example
for example in response_examples
if not isinstance(example, Exception)
]
content_builder = ContentBuilder(self.schema_builder)
response_builder = ResponseBuilder(content_builder)
response_options = ResponseOptions(
success_type_descriptions,
success_examples if self.options.use_examples else None,
self.options.success_responses,
"200",
)
responses = response_builder.build_response(response_options)
# failure response types
if doc_string.raises:
exception_types: Dict[type, str] = {
item.raise_type: item.description for item in doc_string.raises.values()
}
exception_examples = [
example
for example in response_examples
if isinstance(example, Exception)
]
if self.options.error_wrapper:
schema_transformer = schema_error_wrapper
sample_transformer = sample_error_wrapper
else:
schema_transformer = None
sample_transformer = None
content_builder = ContentBuilder(
self.schema_builder,
schema_transformer=schema_transformer,
sample_transformer=sample_transformer,
)
response_builder = ResponseBuilder(content_builder)
response_options = ResponseOptions(
exception_types,
exception_examples if self.options.use_examples else None,
self.options.error_responses,
"500",
)
responses.update(response_builder.build_response(response_options))
if op.event_type is not None:
builder = ContentBuilder(self.schema_builder)
callbacks = {
f"{op.func_name}_callback": {
"{$request.query.callback}": PathItem(
post=Operation(
requestBody=RequestBody(
content=builder.build_content(op.event_type)
),
responses={"200": Response(description="OK")},
)
)
}
}
else:
callbacks = None
return Operation(
tags=[op.defining_class.__name__],
summary=doc_string.short_description,
description=doc_string.long_description,
parameters=parameters,
requestBody=requestBody,
responses=responses,
callbacks=callbacks,
security=[] if op.public else None,
)
def generate(self) -> Document:
paths: Dict[str, PathItem] = {}
endpoint_classes: Set[type] = set()
for op in get_endpoint_operations(
self.endpoint, use_examples=self.options.use_examples
):
endpoint_classes.add(op.defining_class)
operation = self._build_operation(op)
if op.http_method is HTTPMethod.GET:
pathItem = PathItem(get=operation)
elif op.http_method is HTTPMethod.PUT:
pathItem = PathItem(put=operation)
elif op.http_method is HTTPMethod.POST:
pathItem = PathItem(post=operation)
elif op.http_method is HTTPMethod.DELETE:
pathItem = PathItem(delete=operation)
elif op.http_method is HTTPMethod.PATCH:
pathItem = PathItem(patch=operation)
else:
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
route = op.get_route()
if route in paths:
paths[route].update(pathItem)
else:
paths[route] = pathItem
operation_tags: List[Tag] = []
for cls in endpoint_classes:
doc_string = parse_type(cls)
operation_tags.append(
Tag(
name=cls.__name__,
description=doc_string.long_description,
displayName=doc_string.short_description,
)
)
# types that are produced/consumed by operations
type_tags = [
self._build_type_tag(ref, schema)
for ref, schema in self.schema_builder.schemas.items()
]
# types that are emitted by events
event_tags: List[Tag] = []
events = get_endpoint_events(self.endpoint)
for ref, event_type in events.items():
event_schema = self.schema_builder.classdef_to_named_schema(ref, event_type)
event_tags.append(self._build_type_tag(ref, event_schema))
# types that are explicitly declared
extra_tag_groups: Dict[str, List[Tag]] = {}
if self.options.extra_types is not None:
if isinstance(self.options.extra_types, list):
extra_tag_groups = self._build_extra_tag_groups(
{"AdditionalTypes": self.options.extra_types}
)
elif isinstance(self.options.extra_types, dict):
extra_tag_groups = self._build_extra_tag_groups(
self.options.extra_types
)
else:
raise TypeError(
f"type mismatch for collection of extra types: {type(self.options.extra_types)}"
)
# list all operations and types
tags: List[Tag] = []
tags.extend(operation_tags)
tags.extend(type_tags)
tags.extend(event_tags)
for extra_tag_group in extra_tag_groups.values():
tags.extend(extra_tag_group)
tag_groups = []
if operation_tags:
tag_groups.append(
TagGroup(
name=self.options.map("Operations"),
tags=sorted(tag.name for tag in operation_tags),
)
)
if type_tags:
tag_groups.append(
TagGroup(
name=self.options.map("Types"),
tags=sorted(tag.name for tag in type_tags),
)
)
if event_tags:
tag_groups.append(
TagGroup(
name=self.options.map("Events"),
tags=sorted(tag.name for tag in event_tags),
)
)
for caption, extra_tag_group in extra_tag_groups.items():
tag_groups.append(
TagGroup(
name=self.options.map(caption),
tags=sorted(tag.name for tag in extra_tag_group),
)
)
if self.options.default_security_scheme:
securitySchemes = {"Default": self.options.default_security_scheme}
else:
securitySchemes = None
return Document(
openapi=".".join(str(item) for item in self.options.version),
info=self.options.info,
jsonSchemaDialect=(
"https://json-schema.org/draft/2020-12/schema"
if self.options.version >= (3, 1, 0)
else None
),
servers=[self.options.server],
paths=paths,
components=Components(
schemas=self.schema_builder.schemas,
responses=self.responses,
securitySchemes=securitySchemes,
),
security=[{"Default": []}],
tags=tags,
tagGroups=tag_groups,
)

View file

@ -1,386 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import collections.abc
import enum
import inspect
import typing
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from strong_typing.inspection import (
get_signature,
is_type_enum,
is_type_optional,
unwrap_optional_type,
)
from termcolor import colored
def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]]
) -> Tuple[Optional[str], str]:
"""
Recognizes a prefix at the beginning of a string.
:param s: The string to check.
:param sep: A separator between (one of) the prefix(es) and the rest of the string.
:param prefix: A string or a set of strings to identify as a prefix.
:return: A tuple of the recognized prefix (if any) and the rest of the string excluding the separator (or the entire string).
"""
if isinstance(prefix, str):
if s.startswith(prefix + sep):
return prefix, s[len(prefix) + len(sep) :]
else:
return None, s
for p in prefix:
if s.startswith(p + sep):
return p, s[len(p) + len(sep) :]
return None, s
def _get_annotation_type(annotation: Union[type, str], callable: Callable) -> type:
"Maps a stringized reference to a type, as if using `from __future__ import annotations`."
if isinstance(annotation, str):
return eval(annotation, callable.__globals__)
else:
return annotation
class HTTPMethod(enum.Enum):
"HTTP method used to invoke an endpoint operation."
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
OperationParameter = Tuple[str, type]
class ValidationError(TypeError):
pass
@dataclass
class EndpointOperation:
"""
Type information and metadata associated with an endpoint operation.
"param defining_class: The most specific class that defines the endpoint operation.
:param name: The short name of the endpoint operation.
:param func_name: The name of the function to invoke when the operation is triggered.
:param func_ref: The callable to invoke when the operation is triggered.
:param route: A custom route string assigned to the operation.
:param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
:param public: True if the operation can be invoked without prior authentication.
:param request_examples: Sample requests that the operation might take.
:param response_examples: Sample responses that the operation might produce.
"""
defining_class: type
name: str
func_name: str
func_ref: Callable[..., Any]
route: Optional[str]
path_params: List[OperationParameter]
query_params: List[OperationParameter]
request_params: Optional[OperationParameter]
event_type: Optional[type]
response_type: type
http_method: HTTPMethod
public: bool
request_examples: Optional[List[Any]] = None
response_examples: Optional[List[Any]] = None
def get_route(self) -> str:
if self.route is not None:
return self.route
route_parts = ["", self.name]
for param_name, _ in self.path_params:
route_parts.append("{" + param_name + "}")
return "/".join(route_parts)
class _FormatParameterExtractor:
"A visitor to exract parameters in a format string."
keys: List[str]
def __init__(self) -> None:
self.keys = []
def __getitem__(self, key: str) -> None:
self.keys.append(key)
return None
def _get_route_parameters(route: str) -> List[str]:
extractor = _FormatParameterExtractor()
route.format_map(extractor)
return extractor.keys
def _get_endpoint_functions(
endpoint: type, prefixes: List[str]
) -> Iterator[Tuple[str, str, str, Callable]]:
if not inspect.isclass(endpoint):
raise ValueError(f"object is not a class type: {endpoint}")
functions = inspect.getmembers(endpoint, inspect.isfunction)
for func_name, func_ref in functions:
webmethod = getattr(func_ref, "__webmethod__", None)
if not webmethod:
continue
print(f"Processing {colored(func_name, 'white')}...")
operation_name = func_name
if operation_name.startswith("get_") or operation_name.endswith("/get"):
prefix = "get"
elif (
operation_name.startswith("delete_")
or operation_name.startswith("remove_")
or operation_name.endswith("/delete")
or operation_name.endswith("/remove")
):
prefix = "delete"
else:
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
else:
# by default everything else is a POST
prefix = "post"
yield prefix, operation_name, func_name, func_ref
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
"Find the class in which a member function is first defined in a class inheritance hierarchy."
# iterate in reverse member resolution order to find most specific class first
for cls in reversed(inspect.getmro(derived_cls)):
for name, _ in inspect.getmembers(cls, inspect.isfunction):
if name == member_fn:
return cls
raise ValidationError(
f"cannot find defining class for {member_fn} in {derived_cls}"
)
def get_endpoint_operations(
endpoint: type, use_examples: bool = True
) -> List[EndpointOperation]:
"""
Extracts a list of member functions in a class eligible for HTTP interface binding.
These member functions are expected to have a signature like
```
async def get_object(self, uuid: str, version: int) -> Object:
...
```
where the prefix `get_` translates to an HTTP GET, `object` corresponds to the name of the endpoint operation,
`uuid` and `version` are mapped to route path elements in "/object/{uuid}/{version}", and `Object` becomes
the response payload type, transmitted as an object serialized to JSON.
If the member function has a composite class type in the argument list, it becomes the request payload type,
and the caller is expected to provide the data as serialized JSON in an HTTP POST request.
:param endpoint: A class with member functions that can be mapped to an HTTP endpoint.
:param use_examples: Whether to return examples associated with member functions.
"""
result = []
for prefix, operation_name, func_name, func_ref in _get_endpoint_functions(
endpoint,
[
"create",
"delete",
"do",
"get",
"post",
"put",
"remove",
"set",
"update",
],
):
# extract routing information from function metadata
webmethod = getattr(func_ref, "__webmethod__", None)
if webmethod is not None:
route = webmethod.route
route_params = _get_route_parameters(route) if route is not None else None
public = webmethod.public
request_examples = webmethod.request_examples
response_examples = webmethod.response_examples
else:
route = None
route_params = None
public = False
request_examples = None
response_examples = None
# inspect function signature for path and query parameters, and request/response payload type
signature = get_signature(func_ref)
path_params = []
query_params = []
request_params = []
for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref)
# omit "self" for instance methods
if param_name == "self" and param_type is inspect.Parameter.empty:
continue
# check if all parameters have explicit type
if parameter.annotation is inspect.Parameter.empty:
raise ValidationError(
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
)
if is_type_optional(param_type):
inner_type: type = unwrap_optional_type(param_type)
else:
inner_type = param_type
if prefix == "get" and (
inner_type is bool
or inner_type is int
or inner_type is float
or inner_type is str
or inner_type is uuid.UUID
or is_type_enum(inner_type)
):
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
if route_params is not None and param_name not in route_params:
raise ValidationError(
f"positional parameter '{param_name}' absent from user-defined route '{route}' for function '{func_name}'"
)
# simple type maps to route path element, e.g. /study/{uuid}/{version}
path_params.append((param_name, param_type))
else:
if route_params is not None and param_name in route_params:
raise ValidationError(
f"query parameter '{param_name}' found in user-defined route '{route}' for function '{func_name}'"
)
# simple type maps to key=value pair in query string
query_params.append((param_name, param_type))
else:
if route_params is not None and param_name in route_params:
raise ValidationError(
f"user-defined route '{route}' for function '{func_name}' has parameter '{param_name}' of composite type: {param_type}"
)
request_params.append((param_name, param_type))
# check if function has explicit return type
if signature.return_annotation is inspect.Signature.empty:
raise ValidationError(
f"function '{func_name}' has no return type annotation"
)
return_type = _get_annotation_type(signature.return_annotation, func_ref)
# operations that produce events are labeled as Generator[YieldType, SendType, ReturnType]
# where YieldType is the event type, SendType is None, and ReturnType is the immediate response type to the request
if typing.get_origin(return_type) is collections.abc.Generator:
event_type, send_type, response_type = typing.get_args(return_type)
if send_type is not type(None):
raise ValidationError(
f"function '{func_name}' has a return type Generator[Y,S,R] and therefore looks like an event but has an explicit send type"
)
else:
event_type = None
response_type = return_type
# set HTTP request method based on type of request and presence of payload
if not request_params:
if prefix in ["delete", "remove"]:
http_method = HTTPMethod.DELETE
else:
http_method = HTTPMethod.GET
else:
if prefix == "set":
http_method = HTTPMethod.PUT
elif prefix == "update":
http_method = HTTPMethod.PATCH
else:
http_method = HTTPMethod.POST
result.append(
EndpointOperation(
defining_class=_get_defining_class(func_name, endpoint),
name=operation_name,
func_name=func_name,
func_ref=func_ref,
route=route,
path_params=path_params,
query_params=query_params,
request_params=request_params,
event_type=event_type,
response_type=response_type,
http_method=http_method,
public=public,
request_examples=request_examples if use_examples else None,
response_examples=response_examples if use_examples else None,
)
)
if not result:
raise ValidationError(f"no eligible endpoint operations in type {endpoint}")
return result
def get_endpoint_events(endpoint: type) -> Dict[str, type]:
results = {}
for decl in typing.get_type_hints(endpoint).values():
# check if signature is Callable[...]
origin = typing.get_origin(decl)
if origin is None or not issubclass(origin, Callable): # type: ignore
continue
# check if signature is Callable[[...], Any]
args = typing.get_args(decl)
if len(args) != 2:
continue
params_type, return_type = args
if not isinstance(params_type, list):
continue
# check if signature is Callable[[...], None]
if not issubclass(return_type, type(None)):
continue
# check if signature is Callable[[EventType], None]
if len(params_type) != 1:
continue
param_type = params_type[0]
results[param_type.__name__] = param_type
return results

View file

@ -1,75 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import dataclasses
from dataclasses import dataclass
from http import HTTPStatus
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
from .specification import (
Info,
SecurityScheme,
SecuritySchemeAPI,
SecuritySchemeHTTP,
SecuritySchemeOpenIDConnect,
Server,
)
HTTPStatusCode = Union[HTTPStatus, int, str]
@dataclass
class Options:
"""
:param server: Base URL for the API endpoint.
:param info: Meta-information for the endpoint specification.
:param version: OpenAPI specification version as a tuple of major, minor, revision.
:param default_security_scheme: Security scheme to apply to endpoints, unless overridden on a per-endpoint basis.
:param extra_types: Extra types in addition to those found in operation signatures. Use a dictionary to group related types.
:param use_examples: Whether to emit examples for operations.
:param success_responses: Associates operation response types with HTTP status codes.
:param error_responses: Associates error response types with HTTP status codes.
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
"""
server: Server
info: Info
version: Tuple[int, int, int] = (3, 1, 0)
default_security_scheme: Optional[SecurityScheme] = None
extra_types: Union[List[type], Dict[str, List[type]], None] = None
use_examples: bool = True
success_responses: Dict[type, HTTPStatusCode] = dataclasses.field(
default_factory=dict
)
error_responses: Dict[type, HTTPStatusCode] = dataclasses.field(
default_factory=dict
)
error_wrapper: bool = False
property_description_fun: Optional[Callable[[type, str, str], str]] = None
captions: Optional[Dict[str, str]] = None
default_captions: ClassVar[Dict[str, str]] = {
"Operations": "Operations",
"Types": "Types",
"Events": "Events",
"AdditionalTypes": "Additional types",
}
def map(self, id: str) -> str:
"Maps a language-neutral placeholder string to language-dependent text."
if self.captions is not None:
caption = self.captions.get(id)
if caption is not None:
return caption
caption = self.__class__.default_captions.get(id)
if caption is not None:
return caption
raise KeyError(f"no caption found for ID: {id}")

View file

@ -1,258 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import dataclasses
import enum
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Union
from strong_typing.schema import JsonType, Schema, StrictJsonType
URL = str
@dataclass
class Ref:
ref_type: ClassVar[str]
id: str
def to_json(self) -> StrictJsonType:
return {"$ref": f"#/components/{self.ref_type}/{self.id}"}
@dataclass
class SchemaRef(Ref):
ref_type: ClassVar[str] = "schemas"
SchemaOrRef = Union[Schema, SchemaRef]
@dataclass
class ResponseRef(Ref):
ref_type: ClassVar[str] = "responses"
@dataclass
class ParameterRef(Ref):
ref_type: ClassVar[str] = "parameters"
@dataclass
class ExampleRef(Ref):
ref_type: ClassVar[str] = "examples"
@dataclass
class Contact:
name: Optional[str] = None
url: Optional[URL] = None
email: Optional[str] = None
@dataclass
class License:
name: str
url: Optional[URL] = None
@dataclass
class Info:
title: str
version: str
description: Optional[str] = None
termsOfService: Optional[str] = None
contact: Optional[Contact] = None
license: Optional[License] = None
@dataclass
class MediaType:
schema: Optional[SchemaOrRef] = None
example: Optional[Any] = None
examples: Optional[Dict[str, Union["Example", ExampleRef]]] = None
@dataclass
class RequestBody:
content: Dict[str, MediaType]
description: Optional[str] = None
required: Optional[bool] = None
@dataclass
class Response:
description: str
content: Optional[Dict[str, MediaType]] = None
class ParameterLocation(enum.Enum):
Query = "query"
Header = "header"
Path = "path"
Cookie = "cookie"
@dataclass
class Parameter:
name: str
in_: ParameterLocation
description: Optional[str] = None
required: Optional[bool] = None
schema: Optional[SchemaOrRef] = None
example: Optional[Any] = None
@dataclass
class Operation:
responses: Dict[str, Union[Response, ResponseRef]]
tags: Optional[List[str]] = None
summary: Optional[str] = None
description: Optional[str] = None
operationId: Optional[str] = None
parameters: Optional[List[Parameter]] = None
requestBody: Optional[RequestBody] = None
callbacks: Optional[Dict[str, "Callback"]] = None
security: Optional[List["SecurityRequirement"]] = None
@dataclass
class PathItem:
summary: Optional[str] = None
description: Optional[str] = None
get: Optional[Operation] = None
put: Optional[Operation] = None
post: Optional[Operation] = None
delete: Optional[Operation] = None
options: Optional[Operation] = None
head: Optional[Operation] = None
patch: Optional[Operation] = None
trace: Optional[Operation] = None
def update(self, other: "PathItem") -> None:
"Merges another instance of this class into this object."
for field in dataclasses.fields(self.__class__):
value = getattr(other, field.name)
if value is not None:
setattr(self, field.name, value)
# maps run-time expressions such as "$request.body#/url" to path items
Callback = Dict[str, PathItem]
@dataclass
class Example:
summary: Optional[str] = None
description: Optional[str] = None
value: Optional[Any] = None
externalValue: Optional[URL] = None
@dataclass
class Server:
url: URL
description: Optional[str] = None
class SecuritySchemeType(enum.Enum):
ApiKey = "apiKey"
HTTP = "http"
OAuth2 = "oauth2"
OpenIDConnect = "openIdConnect"
@dataclass
class SecurityScheme:
type: SecuritySchemeType
description: str
@dataclass(init=False)
class SecuritySchemeAPI(SecurityScheme):
name: str
in_: ParameterLocation
def __init__(self, description: str, name: str, in_: ParameterLocation) -> None:
super().__init__(SecuritySchemeType.ApiKey, description)
self.name = name
self.in_ = in_
@dataclass(init=False)
class SecuritySchemeHTTP(SecurityScheme):
scheme: str
bearerFormat: Optional[str] = None
def __init__(
self, description: str, scheme: str, bearerFormat: Optional[str] = None
) -> None:
super().__init__(SecuritySchemeType.HTTP, description)
self.scheme = scheme
self.bearerFormat = bearerFormat
@dataclass(init=False)
class SecuritySchemeOpenIDConnect(SecurityScheme):
openIdConnectUrl: str
def __init__(self, description: str, openIdConnectUrl: str) -> None:
super().__init__(SecuritySchemeType.OpenIDConnect, description)
self.openIdConnectUrl = openIdConnectUrl
@dataclass
class Components:
schemas: Optional[Dict[str, Schema]] = None
responses: Optional[Dict[str, Response]] = None
parameters: Optional[Dict[str, Parameter]] = None
examples: Optional[Dict[str, Example]] = None
requestBodies: Optional[Dict[str, RequestBody]] = None
securitySchemes: Optional[Dict[str, SecurityScheme]] = None
callbacks: Optional[Dict[str, Callback]] = None
SecurityScope = str
SecurityRequirement = Dict[str, List[SecurityScope]]
@dataclass
class Tag:
name: str
description: Optional[str] = None
displayName: Optional[str] = None
@dataclass
class TagGroup:
"""
A ReDoc extension to provide information about groups of tags.
Exposed via the vendor-specific property "x-tagGroups" of the top-level object.
"""
name: str
tags: List[str]
@dataclass
class Document:
"""
This class is a Python dataclass adaptation of the OpenAPI Specification.
For details, see <https://swagger.io/specification/>
"""
openapi: str
info: Info
servers: List[Server]
paths: Dict[str, PathItem]
jsonSchemaDialect: Optional[str] = None
components: Optional[Components] = None
security: Optional[List[SecurityRequirement]] = None
tags: Optional[List[Tag]] = None
tagGroups: Optional[List[TagGroup]] = None

View file

@ -1,41 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>OpenAPI specification</title>
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
<style>
body {
margin: 0;
padding: 0;
}
</style>
<script defer="defer" src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
<script defer="defer">
document.addEventListener("DOMContentLoaded", function () {
spec = { /* OPENAPI_SPECIFICATION */ };
options = {
downloadFileName: "openapi.json",
expandResponses: "200",
expandSingleSchemaField: true,
jsonSampleExpandLevel: "all",
schemaExpansionLevel: "all",
};
element = document.getElementById("openapi-container");
Redoc.init(spec, options, element);
if (spec.info && spec.info.title) {
document.title = spec.info.title;
}
});
</script>
</head>
<body>
<div id="openapi-container"></div>
</body>
</html>

View file

@ -1,116 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import typing
from pathlib import Path
from typing import TextIO
from strong_typing.schema import object_to_json, StrictJsonType
from .generator import Generator
from .options import Options
from .specification import Document
THIS_DIR = Path(__file__).parent
class Specification:
document: Document
def __init__(self, endpoint: type, options: Options):
generator = Generator(endpoint, options)
self.document = generator.generate()
def get_json(self) -> StrictJsonType:
"""
Returns the OpenAPI specification as a Python data type (e.g. `dict` for an object, `list` for an array).
The result can be serialized to a JSON string with `json.dump` or `json.dumps`.
"""
json_doc = typing.cast(StrictJsonType, object_to_json(self.document))
if isinstance(json_doc, dict):
# rename vendor-specific properties
tag_groups = json_doc.pop("tagGroups", None)
if tag_groups:
json_doc["x-tagGroups"] = tag_groups
tags = json_doc.get("tags")
if tags and isinstance(tags, list):
for tag in tags:
if not isinstance(tag, dict):
continue
display_name = tag.pop("displayName", None)
if display_name:
tag["x-displayName"] = display_name
return json_doc
def get_json_string(self, pretty_print: bool = False) -> str:
"""
Returns the OpenAPI specification as a JSON string.
:param pretty_print: Whether to use line indents to beautify the output.
"""
json_doc = self.get_json()
if pretty_print:
return json.dumps(
json_doc, check_circular=False, ensure_ascii=False, indent=4
)
else:
return json.dumps(
json_doc,
check_circular=False,
ensure_ascii=False,
separators=(",", ":"),
)
def write_json(self, f: TextIO, pretty_print: bool = False) -> None:
"""
Writes the OpenAPI specification to a file as a JSON string.
:param pretty_print: Whether to use line indents to beautify the output.
"""
json_doc = self.get_json()
if pretty_print:
json.dump(
json_doc,
f,
check_circular=False,
ensure_ascii=False,
indent=4,
)
else:
json.dump(
json_doc,
f,
check_circular=False,
ensure_ascii=False,
separators=(",", ":"),
)
def write_html(self, f: TextIO, pretty_print: bool = False) -> None:
"""
Creates a stand-alone HTML page for the OpenAPI specification with ReDoc.
:param pretty_print: Whether to use line indents to beautify the JSON string in the HTML file.
"""
path = THIS_DIR / "template.html"
with path.open(encoding="utf-8", errors="strict") as html_template_file:
html_template = html_template_file.read()
html = html_template.replace(
"{ /* OPENAPI_SPECIFICATION */ }",
self.get_json_string(pretty_print=pretty_print),
)
f.write(html)

View file

@ -1,31 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
PYTHONPATH=${PYTHONPATH:-}
set -euo pipefail
missing_packages=()
check_package() {
if ! pip show "$1" &>/dev/null; then
missing_packages+=("$1")
fi
}
check_package json-strong-typing
if [ ${#missing_packages[@]} -ne 0 ]; then
echo "Error: The following package(s) are not installed:"
printf " - %s\n" "${missing_packages[@]}"
echo "Please install them using:"
echo "pip install ${missing_packages[*]}"
exit 1
fi
PYTHONPATH=$PYTHONPATH:../.. python -m rfcs.openapi_generator.generate $*