forked from phoenix-oss/llama-stack-mirror
API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)
* add tools to chat completion request * use templates for generating system prompts * Moved ToolPromptFormat and jinja templates to llama_models.llama3.api * <WIP> memory changes - inlined AgenticSystemInstanceConfig so API feels more ergonomic - renamed it to AgentConfig, AgentInstance -> Agent - added a MemoryConfig and `memory` parameter - added `attachments` to input and `output_attachments` to the response - some naming changes * InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool * flesh out memory banks API * agentic loop has a RAG implementation * faiss provider implementation * memory client works * re-work tool definitions, fix FastAPI issues, fix tool regressions * fix agentic_system utils * basic RAG seems to work * small bug fixes for inline attachments * Refactor custom tool execution utilities * Bug fix, show memory retrieval steps in EventLogger * No need for api_key for Remote providers * add special unicode character ↵ to showcase newlines in model prompt templates * remove api.endpoints imports * combine datatypes.py and endpoints.py into api.py * Attachment / add TTL api * split batch_inference from inference * minor import fixes * use a single impl for ChatFormat.decode_assistant_mesage * use interleaved_text_media_as_str() utilityt * Fix api.datatypes imports * Add blobfile for tiktoken * Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly * templates take optional --format={json,function_tag} * Rag Updates * Add `api build` subcommand -- WIP * fix * build + run image seems to work * <WIP> adapters * bunch more work to make adapters work * api build works for conda now * ollama remote adapter works * Several smaller fixes to make adapters work Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight * llama distribution -> llama stack + containers (WIP) * All the new CLI for api + stack work * Make Fireworks and Together into the Adapter format * Some quick fixes to the CLI behavior to make it consistent * Updated README phew * Update cli_reference.md * llama_toolchain/distribution -> llama_toolchain/core * Add termcolor * update paths * Add a log just for consistency * chmod +x scripts * Fix api dependencies not getting added to configuration * missing import lol * Delete utils.py; move to agentic system * Support downloading of URLs for attachments for code interpreter * Simplify and generalize `llama api build` yay * Update `llama stack configure` to be very simple also * Fix stack start * Allow building an "adhoc" distribution * Remote `llama api []` subcommands * Fixes to llama stack commands and update docs * Update documentation again and add error messages to llama stack start * llama stack start -> llama stack run * Change name of build for less confusion * Add pyopenapi fork to the repository, update RFC assets * Remove conflicting annotation * Added a "--raw" option for model template printing --------- Co-authored-by: Hardik Shah <hjshah@fb.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com> Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
This commit is contained in:
parent
35093c0b6f
commit
7bc7785b0d
141 changed files with 8252 additions and 4032 deletions
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -10,81 +10,39 @@
|
|||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Tuple
|
||||
|
||||
import fire
|
||||
import yaml
|
||||
|
||||
from llama_models import schema_utils
|
||||
from pyopenapi import Info, operations, Options, Server, Specification
|
||||
|
||||
# We do a series of monkey-patching to ensure our definitions only use the minimal
|
||||
# We do some monkey-patching to ensure our definitions only use the minimal
|
||||
# (json_schema_type, webmethod) definitions from the llama_models package. For
|
||||
# generation though, we need the full definitions and implementations from the
|
||||
# (python-openapi, json-strong-typing) packages.
|
||||
# (json-strong-typing) package.
|
||||
|
||||
from strong_typing.schema import json_schema_type
|
||||
from termcolor import colored
|
||||
|
||||
from .pyopenapi.options import Options
|
||||
from .pyopenapi.specification import Info, Server
|
||||
from .pyopenapi.utility import Specification
|
||||
|
||||
schema_utils.json_schema_type = json_schema_type
|
||||
|
||||
|
||||
from llama_toolchain.stack import LlamaStack
|
||||
|
||||
|
||||
STREAMING_ENDPOINTS = [
|
||||
"/agentic_system/turn/create"
|
||||
]
|
||||
|
||||
|
||||
def patched_get_endpoint_functions(
|
||||
endpoint: type, prefixes: List[str]
|
||||
) -> Iterator[Tuple[str, str, str, Callable]]:
|
||||
if not inspect.isclass(endpoint):
|
||||
raise ValueError(f"object is not a class type: {endpoint}")
|
||||
|
||||
functions = inspect.getmembers(endpoint, inspect.isfunction)
|
||||
for func_name, func_ref in functions:
|
||||
webmethod = getattr(func_ref, "__webmethod__", None)
|
||||
if not webmethod:
|
||||
continue
|
||||
|
||||
print(f"Processing {colored(func_name, 'white')}...")
|
||||
operation_name = func_name
|
||||
if operation_name.startswith("get_") or operation_name.endswith("/get"):
|
||||
prefix = "get"
|
||||
elif (
|
||||
operation_name.startswith("delete_")
|
||||
or operation_name.startswith("remove_")
|
||||
or operation_name.endswith("/delete")
|
||||
or operation_name.endswith("/remove")
|
||||
):
|
||||
prefix = "delete"
|
||||
else:
|
||||
if webmethod.method == "GET":
|
||||
prefix = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
prefix = "delete"
|
||||
else:
|
||||
# by default everything else is a POST
|
||||
prefix = "post"
|
||||
|
||||
yield prefix, operation_name, func_name, func_ref
|
||||
|
||||
|
||||
# Patch this so all methods are correctly parsed with correct HTTP methods
|
||||
operations._get_endpoint_functions = patched_get_endpoint_functions
|
||||
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
|
||||
STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
|
||||
|
||||
|
||||
def patch_sse_stream_responses(spec: Specification):
|
||||
for path, path_item in spec.document.paths.items():
|
||||
if path in STREAMING_ENDPOINTS:
|
||||
content = path_item.post.responses['200'].content.pop('application/json')
|
||||
path_item.post.responses['200'].content['text/event-stream'] = content
|
||||
content = path_item.post.responses["200"].content.pop("application/json")
|
||||
path_item.post.responses["200"].content["text/event-stream"] = content
|
||||
|
||||
|
||||
def main(output_dir: str):
|
||||
|
|
1
rfcs/openapi_generator/pyopenapi/README.md
Normal file
1
rfcs/openapi_generator/pyopenapi/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
This is forked from https://github.com/hunyadi/pyopenapi
|
5
rfcs/openapi_generator/pyopenapi/__init__.py
Normal file
5
rfcs/openapi_generator/pyopenapi/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
718
rfcs/openapi_generator/pyopenapi/generator.py
Normal file
718
rfcs/openapi_generator/pyopenapi/generator.py
Normal file
|
@ -0,0 +1,718 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import typing
|
||||
from typing import Any, Dict, Set, Union
|
||||
|
||||
from strong_typing.core import JsonType
|
||||
from strong_typing.docstring import Docstring, parse_type
|
||||
from strong_typing.inspection import (
|
||||
is_generic_list,
|
||||
is_type_optional,
|
||||
is_type_union,
|
||||
unwrap_generic_list,
|
||||
unwrap_optional_type,
|
||||
unwrap_union_types,
|
||||
)
|
||||
from strong_typing.name import python_type_to_name
|
||||
from strong_typing.schema import (
|
||||
get_schema_identifier,
|
||||
JsonSchemaGenerator,
|
||||
register_schema,
|
||||
Schema,
|
||||
SchemaOptions,
|
||||
)
|
||||
from strong_typing.serialization import json_dump_string, object_to_json
|
||||
|
||||
from .operations import (
|
||||
EndpointOperation,
|
||||
get_endpoint_events,
|
||||
get_endpoint_operations,
|
||||
HTTPMethod,
|
||||
)
|
||||
from .options import *
|
||||
from .specification import (
|
||||
Components,
|
||||
Document,
|
||||
Example,
|
||||
ExampleRef,
|
||||
MediaType,
|
||||
Operation,
|
||||
Parameter,
|
||||
ParameterLocation,
|
||||
PathItem,
|
||||
RequestBody,
|
||||
Response,
|
||||
ResponseRef,
|
||||
SchemaOrRef,
|
||||
SchemaRef,
|
||||
Tag,
|
||||
TagGroup,
|
||||
)
|
||||
|
||||
register_schema(
|
||||
ipaddress.IPv4Address,
|
||||
schema={
|
||||
"type": "string",
|
||||
"format": "ipv4",
|
||||
"title": "IPv4 address",
|
||||
"description": "IPv4 address, according to dotted-quad ABNF syntax as defined in RFC 2673, section 3.2.",
|
||||
},
|
||||
examples=["192.0.2.0", "198.51.100.1", "203.0.113.255"],
|
||||
)
|
||||
|
||||
register_schema(
|
||||
ipaddress.IPv6Address,
|
||||
schema={
|
||||
"type": "string",
|
||||
"format": "ipv6",
|
||||
"title": "IPv6 address",
|
||||
"description": "IPv6 address, as defined in RFC 2373, section 2.2.",
|
||||
},
|
||||
examples=[
|
||||
"FEDC:BA98:7654:3210:FEDC:BA98:7654:3210",
|
||||
"1080:0:0:0:8:800:200C:417A",
|
||||
"1080::8:800:200C:417A",
|
||||
"FF01::101",
|
||||
"::1",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def http_status_to_string(status_code: HTTPStatusCode) -> str:
|
||||
"Converts an HTTP status code to a string."
|
||||
|
||||
if isinstance(status_code, HTTPStatus):
|
||||
return str(status_code.value)
|
||||
elif isinstance(status_code, int):
|
||||
return str(status_code)
|
||||
elif isinstance(status_code, str):
|
||||
return status_code
|
||||
else:
|
||||
raise TypeError("expected: HTTP status code")
|
||||
|
||||
|
||||
class SchemaBuilder:
|
||||
schema_generator: JsonSchemaGenerator
|
||||
schemas: Dict[str, Schema]
|
||||
|
||||
def __init__(self, schema_generator: JsonSchemaGenerator) -> None:
|
||||
self.schema_generator = schema_generator
|
||||
self.schemas = {}
|
||||
|
||||
def classdef_to_schema(self, typ: type) -> Schema:
|
||||
"""
|
||||
Converts a type to a JSON schema.
|
||||
For nested types found in the type hierarchy, adds the type to the schema registry in the OpenAPI specification section `components`.
|
||||
"""
|
||||
|
||||
type_schema, type_definitions = self.schema_generator.classdef_to_schema(typ)
|
||||
|
||||
# append schema to list of known schemas, to be used in OpenAPI's Components Object section
|
||||
for ref, schema in type_definitions.items():
|
||||
self._add_ref(ref, schema)
|
||||
|
||||
return type_schema
|
||||
|
||||
def classdef_to_named_schema(self, name: str, typ: type) -> Schema:
|
||||
schema = self.classdef_to_schema(typ)
|
||||
self._add_ref(name, schema)
|
||||
return schema
|
||||
|
||||
def classdef_to_ref(self, typ: type) -> SchemaOrRef:
|
||||
"""
|
||||
Converts a type to a JSON schema, and if possible, returns a schema reference.
|
||||
For composite types (such as classes), adds the type to the schema registry in the OpenAPI specification section `components`.
|
||||
"""
|
||||
|
||||
type_schema = self.classdef_to_schema(typ)
|
||||
if typ is str or typ is int or typ is float:
|
||||
# represent simple types as themselves
|
||||
return type_schema
|
||||
|
||||
type_name = get_schema_identifier(typ)
|
||||
if type_name is not None:
|
||||
return self._build_ref(type_name, type_schema)
|
||||
|
||||
try:
|
||||
type_name = python_type_to_name(typ)
|
||||
return self._build_ref(type_name, type_schema)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
return type_schema
|
||||
|
||||
def _build_ref(self, type_name: str, type_schema: Schema) -> SchemaRef:
|
||||
self._add_ref(type_name, type_schema)
|
||||
return SchemaRef(type_name)
|
||||
|
||||
def _add_ref(self, type_name: str, type_schema: Schema) -> None:
|
||||
if type_name not in self.schemas:
|
||||
self.schemas[type_name] = type_schema
|
||||
|
||||
|
||||
class ContentBuilder:
|
||||
schema_builder: SchemaBuilder
|
||||
schema_transformer: Optional[Callable[[SchemaOrRef], SchemaOrRef]]
|
||||
sample_transformer: Optional[Callable[[JsonType], JsonType]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema_builder: SchemaBuilder,
|
||||
schema_transformer: Optional[Callable[[SchemaOrRef], SchemaOrRef]] = None,
|
||||
sample_transformer: Optional[Callable[[JsonType], JsonType]] = None,
|
||||
) -> None:
|
||||
self.schema_builder = schema_builder
|
||||
self.schema_transformer = schema_transformer
|
||||
self.sample_transformer = sample_transformer
|
||||
|
||||
def build_content(
|
||||
self, payload_type: type, examples: Optional[List[Any]] = None
|
||||
) -> Dict[str, MediaType]:
|
||||
"Creates the content subtree for a request or response."
|
||||
|
||||
if is_generic_list(payload_type):
|
||||
media_type = "application/jsonl"
|
||||
item_type = unwrap_generic_list(payload_type)
|
||||
else:
|
||||
media_type = "application/json"
|
||||
item_type = payload_type
|
||||
|
||||
return {media_type: self.build_media_type(item_type, examples)}
|
||||
|
||||
def build_media_type(
|
||||
self, item_type: type, examples: Optional[List[Any]] = None
|
||||
) -> MediaType:
|
||||
schema = self.schema_builder.classdef_to_ref(item_type)
|
||||
if self.schema_transformer:
|
||||
schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = self.schema_transformer # type: ignore
|
||||
schema = schema_transformer(schema)
|
||||
|
||||
if not examples:
|
||||
return MediaType(schema=schema)
|
||||
|
||||
if len(examples) == 1:
|
||||
return MediaType(schema=schema, example=self._build_example(examples[0]))
|
||||
|
||||
return MediaType(
|
||||
schema=schema,
|
||||
examples=self._build_examples(examples),
|
||||
)
|
||||
|
||||
def _build_examples(
|
||||
self, examples: List[Any]
|
||||
) -> Dict[str, Union[Example, ExampleRef]]:
|
||||
"Creates a set of several examples for a media type."
|
||||
|
||||
if self.sample_transformer:
|
||||
sample_transformer: Callable[[JsonType], JsonType] = self.sample_transformer # type: ignore
|
||||
else:
|
||||
sample_transformer = lambda sample: sample
|
||||
|
||||
results: Dict[str, Union[Example, ExampleRef]] = {}
|
||||
for example in examples:
|
||||
value = sample_transformer(object_to_json(example))
|
||||
|
||||
hash_string = (
|
||||
hashlib.md5(json_dump_string(value).encode("utf-8")).digest().hex()
|
||||
)
|
||||
name = f"ex-{hash_string}"
|
||||
|
||||
results[name] = Example(value=value)
|
||||
|
||||
return results
|
||||
|
||||
def _build_example(self, example: Any) -> Any:
|
||||
"Creates a single example for a media type."
|
||||
|
||||
if self.sample_transformer:
|
||||
sample_transformer: Callable[[JsonType], JsonType] = self.sample_transformer # type: ignore
|
||||
else:
|
||||
sample_transformer = lambda sample: sample
|
||||
|
||||
return sample_transformer(object_to_json(example))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseOptions:
|
||||
"""
|
||||
Configuration options for building a response for an operation.
|
||||
|
||||
:param type_descriptions: Maps each response type to a textual description (if available).
|
||||
:param examples: A list of response examples.
|
||||
:param status_catalog: Maps each response type to an HTTP status code.
|
||||
:param default_status_code: HTTP status code assigned to responses that have no mapping.
|
||||
"""
|
||||
|
||||
type_descriptions: Dict[type, str]
|
||||
examples: Optional[List[Any]]
|
||||
status_catalog: Dict[type, HTTPStatusCode]
|
||||
default_status_code: HTTPStatusCode
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusResponse:
|
||||
status_code: str
|
||||
types: List[type] = dataclasses.field(default_factory=list)
|
||||
examples: List[Any] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
class ResponseBuilder:
|
||||
content_builder: ContentBuilder
|
||||
|
||||
def __init__(self, content_builder: ContentBuilder) -> None:
|
||||
self.content_builder = content_builder
|
||||
|
||||
def _get_status_responses(
|
||||
self, options: ResponseOptions
|
||||
) -> Dict[str, StatusResponse]:
|
||||
status_responses: Dict[str, StatusResponse] = {}
|
||||
|
||||
for response_type in options.type_descriptions.keys():
|
||||
status_code = http_status_to_string(
|
||||
options.status_catalog.get(response_type, options.default_status_code)
|
||||
)
|
||||
|
||||
# look up response for status code
|
||||
if status_code not in status_responses:
|
||||
status_responses[status_code] = StatusResponse(status_code)
|
||||
status_response = status_responses[status_code]
|
||||
|
||||
# append response types that are assigned the given status code
|
||||
status_response.types.append(response_type)
|
||||
|
||||
# append examples that have the matching response type
|
||||
if options.examples:
|
||||
status_response.examples.extend(
|
||||
example
|
||||
for example in options.examples
|
||||
if isinstance(example, response_type)
|
||||
)
|
||||
|
||||
return dict(sorted(status_responses.items()))
|
||||
|
||||
def build_response(
|
||||
self, options: ResponseOptions
|
||||
) -> Dict[str, Union[Response, ResponseRef]]:
|
||||
"""
|
||||
Groups responses that have the same status code.
|
||||
"""
|
||||
|
||||
responses: Dict[str, Union[Response, ResponseRef]] = {}
|
||||
status_responses = self._get_status_responses(options)
|
||||
for status_code, status_response in status_responses.items():
|
||||
response_types = tuple(status_response.types)
|
||||
if len(response_types) > 1:
|
||||
composite_response_type: type = Union[response_types] # type: ignore
|
||||
else:
|
||||
(response_type,) = response_types
|
||||
composite_response_type = response_type
|
||||
|
||||
description = " **OR** ".join(
|
||||
filter(
|
||||
None,
|
||||
(
|
||||
options.type_descriptions[response_type]
|
||||
for response_type in response_types
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
responses[status_code] = self._build_response(
|
||||
response_type=composite_response_type,
|
||||
description=description,
|
||||
examples=status_response.examples or None,
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
response_type: type,
|
||||
description: str,
|
||||
examples: Optional[List[Any]] = None,
|
||||
) -> Response:
|
||||
"Creates a response subtree."
|
||||
|
||||
if response_type is not None:
|
||||
return Response(
|
||||
description=description,
|
||||
content=self.content_builder.build_content(response_type, examples),
|
||||
)
|
||||
else:
|
||||
return Response(description=description)
|
||||
|
||||
|
||||
def schema_error_wrapper(schema: SchemaOrRef) -> Schema:
|
||||
"Wraps an error output schema into a top-level error schema."
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"error": schema, # type: ignore
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": [
|
||||
"error",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def sample_error_wrapper(error: JsonType) -> JsonType:
|
||||
"Wraps an error output sample into a top-level error sample."
|
||||
|
||||
return {"error": error}
|
||||
|
||||
|
||||
class Generator:
|
||||
endpoint: type
|
||||
options: Options
|
||||
schema_builder: SchemaBuilder
|
||||
responses: Dict[str, Response]
|
||||
|
||||
def __init__(self, endpoint: type, options: Options) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.options = options
|
||||
schema_generator = JsonSchemaGenerator(
|
||||
SchemaOptions(
|
||||
definitions_path="#/components/schemas/",
|
||||
use_examples=self.options.use_examples,
|
||||
property_description_fun=options.property_description_fun,
|
||||
)
|
||||
)
|
||||
self.schema_builder = SchemaBuilder(schema_generator)
|
||||
self.responses = {}
|
||||
|
||||
def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
|
||||
definition = f'<SchemaDefinition schemaRef="#/components/schemas/{ref}" />'
|
||||
title = typing.cast(str, schema.get("title"))
|
||||
description = typing.cast(str, schema.get("description"))
|
||||
return Tag(
|
||||
name=ref,
|
||||
description="\n\n".join(
|
||||
s for s in (title, description, definition) if s is not None
|
||||
),
|
||||
)
|
||||
|
||||
def _build_extra_tag_groups(
|
||||
self, extra_types: Dict[str, List[type]]
|
||||
) -> Dict[str, List[Tag]]:
|
||||
"""
|
||||
Creates a dictionary of tag group captions as keys, and tag lists as values.
|
||||
|
||||
:param extra_types: A dictionary of type categories and list of types in that category.
|
||||
"""
|
||||
|
||||
extra_tags: Dict[str, List[Tag]] = {}
|
||||
|
||||
for category_name, category_items in extra_types.items():
|
||||
tag_list: List[Tag] = []
|
||||
|
||||
for extra_type in category_items:
|
||||
name = python_type_to_name(extra_type)
|
||||
schema = self.schema_builder.classdef_to_named_schema(name, extra_type)
|
||||
tag_list.append(self._build_type_tag(name, schema))
|
||||
|
||||
if tag_list:
|
||||
extra_tags[category_name] = tag_list
|
||||
|
||||
return extra_tags
|
||||
|
||||
def _build_operation(self, op: EndpointOperation) -> Operation:
|
||||
doc_string = parse_type(op.func_ref)
|
||||
doc_params = dict(
|
||||
(param.name, param.description) for param in doc_string.params.values()
|
||||
)
|
||||
|
||||
# parameters passed in URL component path
|
||||
path_parameters = [
|
||||
Parameter(
|
||||
name=param_name,
|
||||
in_=ParameterLocation.Path,
|
||||
description=doc_params.get(param_name),
|
||||
required=True,
|
||||
schema=self.schema_builder.classdef_to_ref(param_type),
|
||||
)
|
||||
for param_name, param_type in op.path_params
|
||||
]
|
||||
|
||||
# parameters passed in URL component query string
|
||||
query_parameters = []
|
||||
for param_name, param_type in op.query_params:
|
||||
if is_type_optional(param_type):
|
||||
inner_type: type = unwrap_optional_type(param_type)
|
||||
required = False
|
||||
else:
|
||||
inner_type = param_type
|
||||
required = True
|
||||
|
||||
query_parameter = Parameter(
|
||||
name=param_name,
|
||||
in_=ParameterLocation.Query,
|
||||
description=doc_params.get(param_name),
|
||||
required=required,
|
||||
schema=self.schema_builder.classdef_to_ref(inner_type),
|
||||
)
|
||||
query_parameters.append(query_parameter)
|
||||
|
||||
# parameters passed anywhere
|
||||
parameters = path_parameters + query_parameters
|
||||
|
||||
# data passed in payload
|
||||
if op.request_params:
|
||||
builder = ContentBuilder(self.schema_builder)
|
||||
if len(op.request_params) == 1:
|
||||
request_name, request_type = op.request_params[0]
|
||||
else:
|
||||
from dataclasses import make_dataclass
|
||||
|
||||
op_name = "".join(word.capitalize() for word in op.name.split("_"))
|
||||
request_name = f"{op_name}Request"
|
||||
request_type = make_dataclass(request_name, op.request_params)
|
||||
|
||||
requestBody = RequestBody(
|
||||
content={
|
||||
"application/json": builder.build_media_type(
|
||||
request_type, op.request_examples
|
||||
)
|
||||
},
|
||||
description=doc_params.get(request_name),
|
||||
required=True,
|
||||
)
|
||||
else:
|
||||
requestBody = None
|
||||
|
||||
# success response types
|
||||
if doc_string.returns is None and is_type_union(op.response_type):
|
||||
# split union of return types into a list of response types
|
||||
success_type_docstring: Dict[type, Docstring] = {
|
||||
typing.cast(type, item): parse_type(item)
|
||||
for item in unwrap_union_types(op.response_type)
|
||||
}
|
||||
success_type_descriptions = {
|
||||
item: doc_string.short_description
|
||||
for item, doc_string in success_type_docstring.items()
|
||||
if doc_string.short_description
|
||||
}
|
||||
else:
|
||||
# use return type as a single response type
|
||||
success_type_descriptions = {
|
||||
op.response_type: (
|
||||
doc_string.returns.description if doc_string.returns else "OK"
|
||||
)
|
||||
}
|
||||
|
||||
response_examples = op.response_examples or []
|
||||
success_examples = [
|
||||
example
|
||||
for example in response_examples
|
||||
if not isinstance(example, Exception)
|
||||
]
|
||||
|
||||
content_builder = ContentBuilder(self.schema_builder)
|
||||
response_builder = ResponseBuilder(content_builder)
|
||||
response_options = ResponseOptions(
|
||||
success_type_descriptions,
|
||||
success_examples if self.options.use_examples else None,
|
||||
self.options.success_responses,
|
||||
"200",
|
||||
)
|
||||
responses = response_builder.build_response(response_options)
|
||||
|
||||
# failure response types
|
||||
if doc_string.raises:
|
||||
exception_types: Dict[type, str] = {
|
||||
item.raise_type: item.description for item in doc_string.raises.values()
|
||||
}
|
||||
exception_examples = [
|
||||
example
|
||||
for example in response_examples
|
||||
if isinstance(example, Exception)
|
||||
]
|
||||
|
||||
if self.options.error_wrapper:
|
||||
schema_transformer = schema_error_wrapper
|
||||
sample_transformer = sample_error_wrapper
|
||||
else:
|
||||
schema_transformer = None
|
||||
sample_transformer = None
|
||||
|
||||
content_builder = ContentBuilder(
|
||||
self.schema_builder,
|
||||
schema_transformer=schema_transformer,
|
||||
sample_transformer=sample_transformer,
|
||||
)
|
||||
response_builder = ResponseBuilder(content_builder)
|
||||
response_options = ResponseOptions(
|
||||
exception_types,
|
||||
exception_examples if self.options.use_examples else None,
|
||||
self.options.error_responses,
|
||||
"500",
|
||||
)
|
||||
responses.update(response_builder.build_response(response_options))
|
||||
|
||||
if op.event_type is not None:
|
||||
builder = ContentBuilder(self.schema_builder)
|
||||
callbacks = {
|
||||
f"{op.func_name}_callback": {
|
||||
"{$request.query.callback}": PathItem(
|
||||
post=Operation(
|
||||
requestBody=RequestBody(
|
||||
content=builder.build_content(op.event_type)
|
||||
),
|
||||
responses={"200": Response(description="OK")},
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
else:
|
||||
callbacks = None
|
||||
|
||||
return Operation(
|
||||
tags=[op.defining_class.__name__],
|
||||
summary=doc_string.short_description,
|
||||
description=doc_string.long_description,
|
||||
parameters=parameters,
|
||||
requestBody=requestBody,
|
||||
responses=responses,
|
||||
callbacks=callbacks,
|
||||
security=[] if op.public else None,
|
||||
)
|
||||
|
||||
def generate(self) -> Document:
|
||||
paths: Dict[str, PathItem] = {}
|
||||
endpoint_classes: Set[type] = set()
|
||||
for op in get_endpoint_operations(
|
||||
self.endpoint, use_examples=self.options.use_examples
|
||||
):
|
||||
endpoint_classes.add(op.defining_class)
|
||||
|
||||
operation = self._build_operation(op)
|
||||
|
||||
if op.http_method is HTTPMethod.GET:
|
||||
pathItem = PathItem(get=operation)
|
||||
elif op.http_method is HTTPMethod.PUT:
|
||||
pathItem = PathItem(put=operation)
|
||||
elif op.http_method is HTTPMethod.POST:
|
||||
pathItem = PathItem(post=operation)
|
||||
elif op.http_method is HTTPMethod.DELETE:
|
||||
pathItem = PathItem(delete=operation)
|
||||
elif op.http_method is HTTPMethod.PATCH:
|
||||
pathItem = PathItem(patch=operation)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
|
||||
|
||||
route = op.get_route()
|
||||
if route in paths:
|
||||
paths[route].update(pathItem)
|
||||
else:
|
||||
paths[route] = pathItem
|
||||
|
||||
operation_tags: List[Tag] = []
|
||||
for cls in endpoint_classes:
|
||||
doc_string = parse_type(cls)
|
||||
operation_tags.append(
|
||||
Tag(
|
||||
name=cls.__name__,
|
||||
description=doc_string.long_description,
|
||||
displayName=doc_string.short_description,
|
||||
)
|
||||
)
|
||||
|
||||
# types that are produced/consumed by operations
|
||||
type_tags = [
|
||||
self._build_type_tag(ref, schema)
|
||||
for ref, schema in self.schema_builder.schemas.items()
|
||||
]
|
||||
|
||||
# types that are emitted by events
|
||||
event_tags: List[Tag] = []
|
||||
events = get_endpoint_events(self.endpoint)
|
||||
for ref, event_type in events.items():
|
||||
event_schema = self.schema_builder.classdef_to_named_schema(ref, event_type)
|
||||
event_tags.append(self._build_type_tag(ref, event_schema))
|
||||
|
||||
# types that are explicitly declared
|
||||
extra_tag_groups: Dict[str, List[Tag]] = {}
|
||||
if self.options.extra_types is not None:
|
||||
if isinstance(self.options.extra_types, list):
|
||||
extra_tag_groups = self._build_extra_tag_groups(
|
||||
{"AdditionalTypes": self.options.extra_types}
|
||||
)
|
||||
elif isinstance(self.options.extra_types, dict):
|
||||
extra_tag_groups = self._build_extra_tag_groups(
|
||||
self.options.extra_types
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"type mismatch for collection of extra types: {type(self.options.extra_types)}"
|
||||
)
|
||||
|
||||
# list all operations and types
|
||||
tags: List[Tag] = []
|
||||
tags.extend(operation_tags)
|
||||
tags.extend(type_tags)
|
||||
tags.extend(event_tags)
|
||||
for extra_tag_group in extra_tag_groups.values():
|
||||
tags.extend(extra_tag_group)
|
||||
|
||||
tag_groups = []
|
||||
if operation_tags:
|
||||
tag_groups.append(
|
||||
TagGroup(
|
||||
name=self.options.map("Operations"),
|
||||
tags=sorted(tag.name for tag in operation_tags),
|
||||
)
|
||||
)
|
||||
if type_tags:
|
||||
tag_groups.append(
|
||||
TagGroup(
|
||||
name=self.options.map("Types"),
|
||||
tags=sorted(tag.name for tag in type_tags),
|
||||
)
|
||||
)
|
||||
if event_tags:
|
||||
tag_groups.append(
|
||||
TagGroup(
|
||||
name=self.options.map("Events"),
|
||||
tags=sorted(tag.name for tag in event_tags),
|
||||
)
|
||||
)
|
||||
for caption, extra_tag_group in extra_tag_groups.items():
|
||||
tag_groups.append(
|
||||
TagGroup(
|
||||
name=self.options.map(caption),
|
||||
tags=sorted(tag.name for tag in extra_tag_group),
|
||||
)
|
||||
)
|
||||
|
||||
if self.options.default_security_scheme:
|
||||
securitySchemes = {"Default": self.options.default_security_scheme}
|
||||
else:
|
||||
securitySchemes = None
|
||||
|
||||
return Document(
|
||||
openapi=".".join(str(item) for item in self.options.version),
|
||||
info=self.options.info,
|
||||
jsonSchemaDialect=(
|
||||
"https://json-schema.org/draft/2020-12/schema"
|
||||
if self.options.version >= (3, 1, 0)
|
||||
else None
|
||||
),
|
||||
servers=[self.options.server],
|
||||
paths=paths,
|
||||
components=Components(
|
||||
schemas=self.schema_builder.schemas,
|
||||
responses=self.responses,
|
||||
securitySchemes=securitySchemes,
|
||||
),
|
||||
security=[{"Default": []}],
|
||||
tags=tags,
|
||||
tagGroups=tag_groups,
|
||||
)
|
386
rfcs/openapi_generator/pyopenapi/operations.py
Normal file
386
rfcs/openapi_generator/pyopenapi/operations.py
Normal file
|
@ -0,0 +1,386 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import collections.abc
|
||||
import enum
|
||||
import inspect
|
||||
import typing
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from strong_typing.inspection import (
|
||||
get_signature,
|
||||
is_type_enum,
|
||||
is_type_optional,
|
||||
unwrap_optional_type,
|
||||
)
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def split_prefix(
|
||||
s: str, sep: str, prefix: Union[str, Iterable[str]]
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Recognizes a prefix at the beginning of a string.
|
||||
|
||||
:param s: The string to check.
|
||||
:param sep: A separator between (one of) the prefix(es) and the rest of the string.
|
||||
:param prefix: A string or a set of strings to identify as a prefix.
|
||||
:return: A tuple of the recognized prefix (if any) and the rest of the string excluding the separator (or the entire string).
|
||||
"""
|
||||
|
||||
if isinstance(prefix, str):
|
||||
if s.startswith(prefix + sep):
|
||||
return prefix, s[len(prefix) + len(sep) :]
|
||||
else:
|
||||
return None, s
|
||||
|
||||
for p in prefix:
|
||||
if s.startswith(p + sep):
|
||||
return p, s[len(p) + len(sep) :]
|
||||
|
||||
return None, s
|
||||
|
||||
|
||||
def _get_annotation_type(annotation: Union[type, str], callable: Callable) -> type:
|
||||
"Maps a stringized reference to a type, as if using `from __future__ import annotations`."
|
||||
|
||||
if isinstance(annotation, str):
|
||||
return eval(annotation, callable.__globals__)
|
||||
else:
|
||||
return annotation
|
||||
|
||||
|
||||
class HTTPMethod(enum.Enum):
|
||||
"HTTP method used to invoke an endpoint operation."
|
||||
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
OperationParameter = Tuple[str, type]
|
||||
|
||||
|
||||
class ValidationError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointOperation:
|
||||
"""
|
||||
Type information and metadata associated with an endpoint operation.
|
||||
|
||||
"param defining_class: The most specific class that defines the endpoint operation.
|
||||
:param name: The short name of the endpoint operation.
|
||||
:param func_name: The name of the function to invoke when the operation is triggered.
|
||||
:param func_ref: The callable to invoke when the operation is triggered.
|
||||
:param route: A custom route string assigned to the operation.
|
||||
:param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
|
||||
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
|
||||
:param request_params: The parameter that corresponds to the data transmitted in the request body.
|
||||
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
|
||||
:param response_type: The Python type of the data that is transmitted in the response body.
|
||||
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
|
||||
:param public: True if the operation can be invoked without prior authentication.
|
||||
:param request_examples: Sample requests that the operation might take.
|
||||
:param response_examples: Sample responses that the operation might produce.
|
||||
"""
|
||||
|
||||
defining_class: type
|
||||
name: str
|
||||
func_name: str
|
||||
func_ref: Callable[..., Any]
|
||||
route: Optional[str]
|
||||
path_params: List[OperationParameter]
|
||||
query_params: List[OperationParameter]
|
||||
request_params: Optional[OperationParameter]
|
||||
event_type: Optional[type]
|
||||
response_type: type
|
||||
http_method: HTTPMethod
|
||||
public: bool
|
||||
request_examples: Optional[List[Any]] = None
|
||||
response_examples: Optional[List[Any]] = None
|
||||
|
||||
def get_route(self) -> str:
|
||||
if self.route is not None:
|
||||
return self.route
|
||||
|
||||
route_parts = ["", self.name]
|
||||
for param_name, _ in self.path_params:
|
||||
route_parts.append("{" + param_name + "}")
|
||||
return "/".join(route_parts)
|
||||
|
||||
|
||||
class _FormatParameterExtractor:
|
||||
"A visitor to exract parameters in a format string."
|
||||
|
||||
keys: List[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.keys = []
|
||||
|
||||
def __getitem__(self, key: str) -> None:
|
||||
self.keys.append(key)
|
||||
return None
|
||||
|
||||
|
||||
def _get_route_parameters(route: str) -> List[str]:
|
||||
extractor = _FormatParameterExtractor()
|
||||
route.format_map(extractor)
|
||||
return extractor.keys
|
||||
|
||||
|
||||
def _get_endpoint_functions(
|
||||
endpoint: type, prefixes: List[str]
|
||||
) -> Iterator[Tuple[str, str, str, Callable]]:
|
||||
if not inspect.isclass(endpoint):
|
||||
raise ValueError(f"object is not a class type: {endpoint}")
|
||||
|
||||
functions = inspect.getmembers(endpoint, inspect.isfunction)
|
||||
for func_name, func_ref in functions:
|
||||
webmethod = getattr(func_ref, "__webmethod__", None)
|
||||
if not webmethod:
|
||||
continue
|
||||
|
||||
print(f"Processing {colored(func_name, 'white')}...")
|
||||
operation_name = func_name
|
||||
if operation_name.startswith("get_") or operation_name.endswith("/get"):
|
||||
prefix = "get"
|
||||
elif (
|
||||
operation_name.startswith("delete_")
|
||||
or operation_name.startswith("remove_")
|
||||
or operation_name.endswith("/delete")
|
||||
or operation_name.endswith("/remove")
|
||||
):
|
||||
prefix = "delete"
|
||||
else:
|
||||
if webmethod.method == "GET":
|
||||
prefix = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
prefix = "delete"
|
||||
else:
|
||||
# by default everything else is a POST
|
||||
prefix = "post"
|
||||
|
||||
yield prefix, operation_name, func_name, func_ref
|
||||
|
||||
|
||||
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
|
||||
"Find the class in which a member function is first defined in a class inheritance hierarchy."
|
||||
|
||||
# iterate in reverse member resolution order to find most specific class first
|
||||
for cls in reversed(inspect.getmro(derived_cls)):
|
||||
for name, _ in inspect.getmembers(cls, inspect.isfunction):
|
||||
if name == member_fn:
|
||||
return cls
|
||||
|
||||
raise ValidationError(
|
||||
f"cannot find defining class for {member_fn} in {derived_cls}"
|
||||
)
|
||||
|
||||
|
||||
def get_endpoint_operations(
|
||||
endpoint: type, use_examples: bool = True
|
||||
) -> List[EndpointOperation]:
|
||||
"""
|
||||
Extracts a list of member functions in a class eligible for HTTP interface binding.
|
||||
|
||||
These member functions are expected to have a signature like
|
||||
```
|
||||
async def get_object(self, uuid: str, version: int) -> Object:
|
||||
...
|
||||
```
|
||||
where the prefix `get_` translates to an HTTP GET, `object` corresponds to the name of the endpoint operation,
|
||||
`uuid` and `version` are mapped to route path elements in "/object/{uuid}/{version}", and `Object` becomes
|
||||
the response payload type, transmitted as an object serialized to JSON.
|
||||
|
||||
If the member function has a composite class type in the argument list, it becomes the request payload type,
|
||||
and the caller is expected to provide the data as serialized JSON in an HTTP POST request.
|
||||
|
||||
:param endpoint: A class with member functions that can be mapped to an HTTP endpoint.
|
||||
:param use_examples: Whether to return examples associated with member functions.
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
for prefix, operation_name, func_name, func_ref in _get_endpoint_functions(
|
||||
endpoint,
|
||||
[
|
||||
"create",
|
||||
"delete",
|
||||
"do",
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"remove",
|
||||
"set",
|
||||
"update",
|
||||
],
|
||||
):
|
||||
# extract routing information from function metadata
|
||||
webmethod = getattr(func_ref, "__webmethod__", None)
|
||||
if webmethod is not None:
|
||||
route = webmethod.route
|
||||
route_params = _get_route_parameters(route) if route is not None else None
|
||||
public = webmethod.public
|
||||
request_examples = webmethod.request_examples
|
||||
response_examples = webmethod.response_examples
|
||||
else:
|
||||
route = None
|
||||
route_params = None
|
||||
public = False
|
||||
request_examples = None
|
||||
response_examples = None
|
||||
|
||||
# inspect function signature for path and query parameters, and request/response payload type
|
||||
signature = get_signature(func_ref)
|
||||
|
||||
path_params = []
|
||||
query_params = []
|
||||
request_params = []
|
||||
|
||||
for param_name, parameter in signature.parameters.items():
|
||||
param_type = _get_annotation_type(parameter.annotation, func_ref)
|
||||
|
||||
# omit "self" for instance methods
|
||||
if param_name == "self" and param_type is inspect.Parameter.empty:
|
||||
continue
|
||||
|
||||
# check if all parameters have explicit type
|
||||
if parameter.annotation is inspect.Parameter.empty:
|
||||
raise ValidationError(
|
||||
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
|
||||
)
|
||||
|
||||
if is_type_optional(param_type):
|
||||
inner_type: type = unwrap_optional_type(param_type)
|
||||
else:
|
||||
inner_type = param_type
|
||||
|
||||
if (
|
||||
inner_type is bool
|
||||
or inner_type is int
|
||||
or inner_type is float
|
||||
or inner_type is str
|
||||
or inner_type is uuid.UUID
|
||||
or is_type_enum(inner_type)
|
||||
):
|
||||
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
|
||||
if route_params is not None and param_name not in route_params:
|
||||
raise ValidationError(
|
||||
f"positional parameter '{param_name}' absent from user-defined route '{route}' for function '{func_name}'"
|
||||
)
|
||||
|
||||
# simple type maps to route path element, e.g. /study/{uuid}/{version}
|
||||
path_params.append((param_name, param_type))
|
||||
else:
|
||||
if route_params is not None and param_name in route_params:
|
||||
raise ValidationError(
|
||||
f"query parameter '{param_name}' found in user-defined route '{route}' for function '{func_name}'"
|
||||
)
|
||||
|
||||
# simple type maps to key=value pair in query string
|
||||
query_params.append((param_name, param_type))
|
||||
else:
|
||||
if route_params is not None and param_name in route_params:
|
||||
raise ValidationError(
|
||||
f"user-defined route '{route}' for function '{func_name}' has parameter '{param_name}' of composite type: {param_type}"
|
||||
)
|
||||
|
||||
request_params.append((param_name, param_type))
|
||||
|
||||
# check if function has explicit return type
|
||||
if signature.return_annotation is inspect.Signature.empty:
|
||||
raise ValidationError(
|
||||
f"function '{func_name}' has no return type annotation"
|
||||
)
|
||||
|
||||
return_type = _get_annotation_type(signature.return_annotation, func_ref)
|
||||
|
||||
# operations that produce events are labeled as Generator[YieldType, SendType, ReturnType]
|
||||
# where YieldType is the event type, SendType is None, and ReturnType is the immediate response type to the request
|
||||
if typing.get_origin(return_type) is collections.abc.Generator:
|
||||
event_type, send_type, response_type = typing.get_args(return_type)
|
||||
if send_type is not type(None):
|
||||
raise ValidationError(
|
||||
f"function '{func_name}' has a return type Generator[Y,S,R] and therefore looks like an event but has an explicit send type"
|
||||
)
|
||||
else:
|
||||
event_type = None
|
||||
response_type = return_type
|
||||
|
||||
# set HTTP request method based on type of request and presence of payload
|
||||
if not request_params:
|
||||
if prefix in ["delete", "remove"]:
|
||||
http_method = HTTPMethod.DELETE
|
||||
else:
|
||||
http_method = HTTPMethod.GET
|
||||
else:
|
||||
if prefix == "set":
|
||||
http_method = HTTPMethod.PUT
|
||||
elif prefix == "update":
|
||||
http_method = HTTPMethod.PATCH
|
||||
else:
|
||||
http_method = HTTPMethod.POST
|
||||
|
||||
result.append(
|
||||
EndpointOperation(
|
||||
defining_class=_get_defining_class(func_name, endpoint),
|
||||
name=operation_name,
|
||||
func_name=func_name,
|
||||
func_ref=func_ref,
|
||||
route=route,
|
||||
path_params=path_params,
|
||||
query_params=query_params,
|
||||
request_params=request_params,
|
||||
event_type=event_type,
|
||||
response_type=response_type,
|
||||
http_method=http_method,
|
||||
public=public,
|
||||
request_examples=request_examples if use_examples else None,
|
||||
response_examples=response_examples if use_examples else None,
|
||||
)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ValidationError(f"no eligible endpoint operations in type {endpoint}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_endpoint_events(endpoint: type) -> Dict[str, type]:
|
||||
results = {}
|
||||
|
||||
for decl in typing.get_type_hints(endpoint).values():
|
||||
# check if signature is Callable[...]
|
||||
origin = typing.get_origin(decl)
|
||||
if origin is None or not issubclass(origin, Callable): # type: ignore
|
||||
continue
|
||||
|
||||
# check if signature is Callable[[...], Any]
|
||||
args = typing.get_args(decl)
|
||||
if len(args) != 2:
|
||||
continue
|
||||
params_type, return_type = args
|
||||
if not isinstance(params_type, list):
|
||||
continue
|
||||
|
||||
# check if signature is Callable[[...], None]
|
||||
if not issubclass(return_type, type(None)):
|
||||
continue
|
||||
|
||||
# check if signature is Callable[[EventType], None]
|
||||
if len(params_type) != 1:
|
||||
continue
|
||||
|
||||
param_type = params_type[0]
|
||||
results[param_type.__name__] = param_type
|
||||
|
||||
return results
|
75
rfcs/openapi_generator/pyopenapi/options.py
Normal file
75
rfcs/openapi_generator/pyopenapi/options.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .specification import (
|
||||
Info,
|
||||
SecurityScheme,
|
||||
SecuritySchemeAPI,
|
||||
SecuritySchemeHTTP,
|
||||
SecuritySchemeOpenIDConnect,
|
||||
Server,
|
||||
)
|
||||
|
||||
HTTPStatusCode = Union[HTTPStatus, int, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Options:
|
||||
"""
|
||||
:param server: Base URL for the API endpoint.
|
||||
:param info: Meta-information for the endpoint specification.
|
||||
:param version: OpenAPI specification version as a tuple of major, minor, revision.
|
||||
:param default_security_scheme: Security scheme to apply to endpoints, unless overridden on a per-endpoint basis.
|
||||
:param extra_types: Extra types in addition to those found in operation signatures. Use a dictionary to group related types.
|
||||
:param use_examples: Whether to emit examples for operations.
|
||||
:param success_responses: Associates operation response types with HTTP status codes.
|
||||
:param error_responses: Associates error response types with HTTP status codes.
|
||||
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
|
||||
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
|
||||
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
|
||||
"""
|
||||
|
||||
server: Server
|
||||
info: Info
|
||||
version: Tuple[int, int, int] = (3, 1, 0)
|
||||
default_security_scheme: Optional[SecurityScheme] = None
|
||||
extra_types: Union[List[type], Dict[str, List[type]], None] = None
|
||||
use_examples: bool = True
|
||||
success_responses: Dict[type, HTTPStatusCode] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
error_responses: Dict[type, HTTPStatusCode] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
error_wrapper: bool = False
|
||||
property_description_fun: Optional[Callable[[type, str, str], str]] = None
|
||||
captions: Optional[Dict[str, str]] = None
|
||||
|
||||
default_captions: ClassVar[Dict[str, str]] = {
|
||||
"Operations": "Operations",
|
||||
"Types": "Types",
|
||||
"Events": "Events",
|
||||
"AdditionalTypes": "Additional types",
|
||||
}
|
||||
|
||||
def map(self, id: str) -> str:
|
||||
"Maps a language-neutral placeholder string to language-dependent text."
|
||||
|
||||
if self.captions is not None:
|
||||
caption = self.captions.get(id)
|
||||
if caption is not None:
|
||||
return caption
|
||||
|
||||
caption = self.__class__.default_captions.get(id)
|
||||
if caption is not None:
|
||||
return caption
|
||||
|
||||
raise KeyError(f"no caption found for ID: {id}")
|
258
rfcs/openapi_generator/pyopenapi/specification.py
Normal file
258
rfcs/openapi_generator/pyopenapi/specification.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Union
|
||||
|
||||
from strong_typing.schema import JsonType, Schema, StrictJsonType
|
||||
|
||||
URL = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ref:
|
||||
ref_type: ClassVar[str]
|
||||
id: str
|
||||
|
||||
def to_json(self) -> StrictJsonType:
|
||||
return {"$ref": f"#/components/{self.ref_type}/{self.id}"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchemaRef(Ref):
|
||||
ref_type: ClassVar[str] = "schemas"
|
||||
|
||||
|
||||
SchemaOrRef = Union[Schema, SchemaRef]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseRef(Ref):
|
||||
ref_type: ClassVar[str] = "responses"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterRef(Ref):
|
||||
ref_type: ClassVar[str] = "parameters"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleRef(Ref):
|
||||
ref_type: ClassVar[str] = "examples"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Contact:
|
||||
name: Optional[str] = None
|
||||
url: Optional[URL] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class License:
|
||||
name: str
|
||||
url: Optional[URL] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Info:
|
||||
title: str
|
||||
version: str
|
||||
description: Optional[str] = None
|
||||
termsOfService: Optional[str] = None
|
||||
contact: Optional[Contact] = None
|
||||
license: Optional[License] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaType:
|
||||
schema: Optional[SchemaOrRef] = None
|
||||
example: Optional[Any] = None
|
||||
examples: Optional[Dict[str, Union["Example", ExampleRef]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestBody:
|
||||
content: Dict[str, MediaType]
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
description: str
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
|
||||
|
||||
class ParameterLocation(enum.Enum):
|
||||
Query = "query"
|
||||
Header = "header"
|
||||
Path = "path"
|
||||
Cookie = "cookie"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Parameter:
|
||||
name: str
|
||||
in_: ParameterLocation
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
schema: Optional[SchemaOrRef] = None
|
||||
example: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Operation:
|
||||
responses: Dict[str, Union[Response, ResponseRef]]
|
||||
tags: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
operationId: Optional[str] = None
|
||||
parameters: Optional[List[Parameter]] = None
|
||||
requestBody: Optional[RequestBody] = None
|
||||
callbacks: Optional[Dict[str, "Callback"]] = None
|
||||
security: Optional[List["SecurityRequirement"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathItem:
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
get: Optional[Operation] = None
|
||||
put: Optional[Operation] = None
|
||||
post: Optional[Operation] = None
|
||||
delete: Optional[Operation] = None
|
||||
options: Optional[Operation] = None
|
||||
head: Optional[Operation] = None
|
||||
patch: Optional[Operation] = None
|
||||
trace: Optional[Operation] = None
|
||||
|
||||
def update(self, other: "PathItem") -> None:
|
||||
"Merges another instance of this class into this object."
|
||||
|
||||
for field in dataclasses.fields(self.__class__):
|
||||
value = getattr(other, field.name)
|
||||
if value is not None:
|
||||
setattr(self, field.name, value)
|
||||
|
||||
|
||||
# maps run-time expressions such as "$request.body#/url" to path items
|
||||
Callback = Dict[str, PathItem]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Example:
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
value: Optional[Any] = None
|
||||
externalValue: Optional[URL] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Server:
|
||||
url: URL
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SecuritySchemeType(enum.Enum):
|
||||
ApiKey = "apiKey"
|
||||
HTTP = "http"
|
||||
OAuth2 = "oauth2"
|
||||
OpenIDConnect = "openIdConnect"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityScheme:
|
||||
type: SecuritySchemeType
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SecuritySchemeAPI(SecurityScheme):
|
||||
name: str
|
||||
in_: ParameterLocation
|
||||
|
||||
def __init__(self, description: str, name: str, in_: ParameterLocation) -> None:
|
||||
super().__init__(SecuritySchemeType.ApiKey, description)
|
||||
self.name = name
|
||||
self.in_ = in_
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SecuritySchemeHTTP(SecurityScheme):
|
||||
scheme: str
|
||||
bearerFormat: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self, description: str, scheme: str, bearerFormat: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__(SecuritySchemeType.HTTP, description)
|
||||
self.scheme = scheme
|
||||
self.bearerFormat = bearerFormat
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SecuritySchemeOpenIDConnect(SecurityScheme):
|
||||
openIdConnectUrl: str
|
||||
|
||||
def __init__(self, description: str, openIdConnectUrl: str) -> None:
|
||||
super().__init__(SecuritySchemeType.OpenIDConnect, description)
|
||||
self.openIdConnectUrl = openIdConnectUrl
|
||||
|
||||
|
||||
@dataclass
|
||||
class Components:
|
||||
schemas: Optional[Dict[str, Schema]] = None
|
||||
responses: Optional[Dict[str, Response]] = None
|
||||
parameters: Optional[Dict[str, Parameter]] = None
|
||||
examples: Optional[Dict[str, Example]] = None
|
||||
requestBodies: Optional[Dict[str, RequestBody]] = None
|
||||
securitySchemes: Optional[Dict[str, SecurityScheme]] = None
|
||||
callbacks: Optional[Dict[str, Callback]] = None
|
||||
|
||||
|
||||
SecurityScope = str
|
||||
SecurityRequirement = Dict[str, List[SecurityScope]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tag:
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
displayName: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TagGroup:
|
||||
"""
|
||||
A ReDoc extension to provide information about groups of tags.
|
||||
|
||||
Exposed via the vendor-specific property "x-tagGroups" of the top-level object.
|
||||
"""
|
||||
|
||||
name: str
|
||||
tags: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
"""
|
||||
This class is a Python dataclass adaptation of the OpenAPI Specification.
|
||||
|
||||
For details, see <https://swagger.io/specification/>
|
||||
"""
|
||||
|
||||
openapi: str
|
||||
info: Info
|
||||
servers: List[Server]
|
||||
paths: Dict[str, PathItem]
|
||||
jsonSchemaDialect: Optional[str] = None
|
||||
components: Optional[Components] = None
|
||||
security: Optional[List[SecurityRequirement]] = None
|
||||
tags: Optional[List[Tag]] = None
|
||||
tagGroups: Optional[List[TagGroup]] = None
|
41
rfcs/openapi_generator/pyopenapi/template.html
Normal file
41
rfcs/openapi_generator/pyopenapi/template.html
Normal file
|
@ -0,0 +1,41 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>OpenAPI specification</title>
|
||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
</style>
|
||||
<script defer="defer" src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
|
||||
<script defer="defer">
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
spec = { /* OPENAPI_SPECIFICATION */ };
|
||||
options = {
|
||||
downloadFileName: "openapi.json",
|
||||
expandResponses: "200",
|
||||
expandSingleSchemaField: true,
|
||||
jsonSampleExpandLevel: "all",
|
||||
schemaExpansionLevel: "all",
|
||||
};
|
||||
element = document.getElementById("openapi-container");
|
||||
Redoc.init(spec, options, element);
|
||||
|
||||
if (spec.info && spec.info.title) {
|
||||
document.title = spec.info.title;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div id="openapi-container"></div>
|
||||
</body>
|
||||
|
||||
</html>
|
116
rfcs/openapi_generator/pyopenapi/utility.py
Normal file
116
rfcs/openapi_generator/pyopenapi/utility.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
from strong_typing.schema import object_to_json, StrictJsonType
|
||||
|
||||
from .generator import Generator
|
||||
from .options import Options
|
||||
from .specification import Document
|
||||
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class Specification:
|
||||
document: Document
|
||||
|
||||
def __init__(self, endpoint: type, options: Options):
|
||||
generator = Generator(endpoint, options)
|
||||
self.document = generator.generate()
|
||||
|
||||
def get_json(self) -> StrictJsonType:
|
||||
"""
|
||||
Returns the OpenAPI specification as a Python data type (e.g. `dict` for an object, `list` for an array).
|
||||
|
||||
The result can be serialized to a JSON string with `json.dump` or `json.dumps`.
|
||||
"""
|
||||
|
||||
json_doc = typing.cast(StrictJsonType, object_to_json(self.document))
|
||||
|
||||
if isinstance(json_doc, dict):
|
||||
# rename vendor-specific properties
|
||||
tag_groups = json_doc.pop("tagGroups", None)
|
||||
if tag_groups:
|
||||
json_doc["x-tagGroups"] = tag_groups
|
||||
tags = json_doc.get("tags")
|
||||
if tags and isinstance(tags, list):
|
||||
for tag in tags:
|
||||
if not isinstance(tag, dict):
|
||||
continue
|
||||
|
||||
display_name = tag.pop("displayName", None)
|
||||
if display_name:
|
||||
tag["x-displayName"] = display_name
|
||||
|
||||
return json_doc
|
||||
|
||||
def get_json_string(self, pretty_print: bool = False) -> str:
|
||||
"""
|
||||
Returns the OpenAPI specification as a JSON string.
|
||||
|
||||
:param pretty_print: Whether to use line indents to beautify the output.
|
||||
"""
|
||||
|
||||
json_doc = self.get_json()
|
||||
if pretty_print:
|
||||
return json.dumps(
|
||||
json_doc, check_circular=False, ensure_ascii=False, indent=4
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
json_doc,
|
||||
check_circular=False,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
def write_json(self, f: TextIO, pretty_print: bool = False) -> None:
|
||||
"""
|
||||
Writes the OpenAPI specification to a file as a JSON string.
|
||||
|
||||
:param pretty_print: Whether to use line indents to beautify the output.
|
||||
"""
|
||||
|
||||
json_doc = self.get_json()
|
||||
if pretty_print:
|
||||
json.dump(
|
||||
json_doc,
|
||||
f,
|
||||
check_circular=False,
|
||||
ensure_ascii=False,
|
||||
indent=4,
|
||||
)
|
||||
else:
|
||||
json.dump(
|
||||
json_doc,
|
||||
f,
|
||||
check_circular=False,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
def write_html(self, f: TextIO, pretty_print: bool = False) -> None:
|
||||
"""
|
||||
Creates a stand-alone HTML page for the OpenAPI specification with ReDoc.
|
||||
|
||||
:param pretty_print: Whether to use line indents to beautify the JSON string in the HTML file.
|
||||
"""
|
||||
|
||||
path = THIS_DIR / "template.html"
|
||||
with path.open(encoding="utf-8", errors="strict") as html_template_file:
|
||||
html_template = html_template_file.read()
|
||||
|
||||
html = html_template.replace(
|
||||
"{ /* OPENAPI_SPECIFICATION */ }",
|
||||
self.get_json_string(pretty_print=pretty_print),
|
||||
)
|
||||
|
||||
f.write(html)
|
|
@ -1,6 +1,5 @@
|
|||
#!/bin/bash
|
||||
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@ -14,12 +13,11 @@ set -euo pipefail
|
|||
missing_packages=()
|
||||
|
||||
check_package() {
|
||||
if ! pip show "$1" &> /dev/null; then
|
||||
if ! pip show "$1" &>/dev/null; then
|
||||
missing_packages+=("$1")
|
||||
fi
|
||||
}
|
||||
|
||||
check_package python-openapi
|
||||
check_package json-strong-typing
|
||||
|
||||
if [ ${#missing_packages[@]} -ne 0 ]; then
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue