mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
	# What does this PR do? * Adds the OpenAI compatible Files API * Modified doc gen script to support multipart parameter ## Test Plan --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/2321). * #2330 * __->__ #2321
		
			
				
	
	
		
			938 lines
		
	
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			938 lines
		
	
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import hashlib
 | |
| import ipaddress
 | |
| import types
 | |
| import typing
 | |
| from dataclasses import make_dataclass
 | |
| from typing import Any, Dict, Set, Union
 | |
| 
 | |
| from llama_stack.apis.datatypes import Error
 | |
| from llama_stack.strong_typing.core import JsonType
 | |
| from llama_stack.strong_typing.docstring import Docstring, parse_type
 | |
| from llama_stack.strong_typing.inspection import (
 | |
|     is_generic_list,
 | |
|     is_type_optional,
 | |
|     is_type_union,
 | |
|     unwrap_generic_list,
 | |
|     unwrap_optional_type,
 | |
|     unwrap_union_types,
 | |
| )
 | |
| from llama_stack.strong_typing.name import python_type_to_name
 | |
| from llama_stack.strong_typing.schema import (
 | |
|     get_schema_identifier,
 | |
|     JsonSchemaGenerator,
 | |
|     register_schema,
 | |
|     Schema,
 | |
|     SchemaOptions,
 | |
| )
 | |
| from typing import get_origin, get_args
 | |
| from typing import Annotated
 | |
| from fastapi import UploadFile
 | |
| from llama_stack.strong_typing.serialization import json_dump_string, object_to_json
 | |
| 
 | |
| from .operations import (
 | |
|     EndpointOperation,
 | |
|     get_endpoint_events,
 | |
|     get_endpoint_operations,
 | |
|     HTTPMethod,
 | |
| )
 | |
| from .options import *
 | |
| from .specification import (
 | |
|     Components,
 | |
|     Document,
 | |
|     Example,
 | |
|     ExampleRef,
 | |
|     MediaType,
 | |
|     Operation,
 | |
|     Parameter,
 | |
|     ParameterLocation,
 | |
|     PathItem,
 | |
|     RequestBody,
 | |
|     Response,
 | |
|     ResponseRef,
 | |
|     SchemaOrRef,
 | |
|     SchemaRef,
 | |
|     Tag,
 | |
|     TagGroup,
 | |
| )
 | |
| 
 | |
| register_schema(
 | |
|     ipaddress.IPv4Address,
 | |
|     schema={
 | |
|         "type": "string",
 | |
|         "format": "ipv4",
 | |
|         "title": "IPv4 address",
 | |
|         "description": "IPv4 address, according to dotted-quad ABNF syntax as defined in RFC 2673, section 3.2.",
 | |
|     },
 | |
|     examples=["192.0.2.0", "198.51.100.1", "203.0.113.255"],
 | |
| )
 | |
| 
 | |
| register_schema(
 | |
|     ipaddress.IPv6Address,
 | |
|     schema={
 | |
|         "type": "string",
 | |
|         "format": "ipv6",
 | |
|         "title": "IPv6 address",
 | |
|         "description": "IPv6 address, as defined in RFC 2373, section 2.2.",
 | |
|     },
 | |
|     examples=[
 | |
|         "FEDC:BA98:7654:3210:FEDC:BA98:7654:3210",
 | |
|         "1080:0:0:0:8:800:200C:417A",
 | |
|         "1080::8:800:200C:417A",
 | |
|         "FF01::101",
 | |
|         "::1",
 | |
|     ],
 | |
| )
 | |
| 
 | |
| 
 | |
| def http_status_to_string(status_code: HTTPStatusCode) -> str:
 | |
|     "Converts an HTTP status code to a string."
 | |
| 
 | |
|     if isinstance(status_code, HTTPStatus):
 | |
|         return str(status_code.value)
 | |
|     elif isinstance(status_code, int):
 | |
|         return str(status_code)
 | |
|     elif isinstance(status_code, str):
 | |
|         return status_code
 | |
|     else:
 | |
|         raise TypeError("expected: HTTP status code")
 | |
| 
 | |
| 
 | |
| class SchemaBuilder:
 | |
|     schema_generator: JsonSchemaGenerator
 | |
|     schemas: Dict[str, Schema]
 | |
| 
 | |
|     def __init__(self, schema_generator: JsonSchemaGenerator) -> None:
 | |
|         self.schema_generator = schema_generator
 | |
|         self.schemas = {}
 | |
| 
 | |
|     def classdef_to_schema(self, typ: type) -> Schema:
 | |
|         """
 | |
|         Converts a type to a JSON schema.
 | |
|         For nested types found in the type hierarchy, adds the type to the schema registry in the OpenAPI specification section `components`.
 | |
|         """
 | |
| 
 | |
|         type_schema, type_definitions = self.schema_generator.classdef_to_schema(typ)
 | |
| 
 | |
|         # append schema to list of known schemas, to be used in OpenAPI's Components Object section
 | |
|         for ref, schema in type_definitions.items():
 | |
|             self._add_ref(ref, schema)
 | |
| 
 | |
|         return type_schema
 | |
| 
 | |
|     def classdef_to_named_schema(self, name: str, typ: type) -> Schema:
 | |
|         schema = self.classdef_to_schema(typ)
 | |
|         self._add_ref(name, schema)
 | |
|         return schema
 | |
| 
 | |
|     def classdef_to_ref(self, typ: type) -> SchemaOrRef:
 | |
|         """
 | |
|         Converts a type to a JSON schema, and if possible, returns a schema reference.
 | |
|         For composite types (such as classes), adds the type to the schema registry in the OpenAPI specification section `components`.
 | |
|         """
 | |
| 
 | |
|         type_schema = self.classdef_to_schema(typ)
 | |
|         if typ is str or typ is int or typ is float:
 | |
|             # represent simple types as themselves
 | |
|             return type_schema
 | |
| 
 | |
|         type_name = get_schema_identifier(typ)
 | |
|         if type_name is not None:
 | |
|             return self._build_ref(type_name, type_schema)
 | |
| 
 | |
|         try:
 | |
|             type_name = python_type_to_name(typ)
 | |
|             return self._build_ref(type_name, type_schema)
 | |
|         except TypeError:
 | |
|             pass
 | |
| 
 | |
|         return type_schema
 | |
| 
 | |
|     def _build_ref(self, type_name: str, type_schema: Schema) -> SchemaRef:
 | |
|         self._add_ref(type_name, type_schema)
 | |
|         return SchemaRef(type_name)
 | |
| 
 | |
|     def _add_ref(self, type_name: str, type_schema: Schema) -> None:
 | |
|         if type_name not in self.schemas:
 | |
|             self.schemas[type_name] = type_schema
 | |
| 
 | |
| 
 | |
| class ContentBuilder:
 | |
|     schema_builder: SchemaBuilder
 | |
|     schema_transformer: Optional[Callable[[SchemaOrRef], SchemaOrRef]]
 | |
|     sample_transformer: Optional[Callable[[JsonType], JsonType]]
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         schema_builder: SchemaBuilder,
 | |
|         schema_transformer: Optional[Callable[[SchemaOrRef], SchemaOrRef]] = None,
 | |
|         sample_transformer: Optional[Callable[[JsonType], JsonType]] = None,
 | |
|     ) -> None:
 | |
|         self.schema_builder = schema_builder
 | |
|         self.schema_transformer = schema_transformer
 | |
|         self.sample_transformer = sample_transformer
 | |
| 
 | |
|     def build_content(
 | |
|         self, payload_type: type, examples: Optional[List[Any]] = None
 | |
|     ) -> Dict[str, MediaType]:
 | |
|         "Creates the content subtree for a request or response."
 | |
| 
 | |
|         def is_iterator_type(t):
 | |
|             return "StreamChunk" in str(t) or "OpenAIResponseObjectStream" in str(t)
 | |
| 
 | |
|         def get_media_type(t):
 | |
|             if is_generic_list(t):
 | |
|                 return "application/jsonl"
 | |
|             elif is_iterator_type(t):
 | |
|                 return "text/event-stream"
 | |
|             else:
 | |
|                 return "application/json"
 | |
| 
 | |
|         if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
 | |
|             media_types = []
 | |
|             item_types = []
 | |
|             for x in typing.get_args(payload_type):
 | |
|                 media_types.append(get_media_type(x))
 | |
|                 item_types.append(x)
 | |
| 
 | |
|             if len(set(media_types)) == 1:
 | |
|                 # all types have the same media type
 | |
|                 return {media_types[0]: self.build_media_type(payload_type, examples)}
 | |
|             else:
 | |
|                 # different types have different media types
 | |
|                 return {
 | |
|                     media_type: self.build_media_type(item_type, examples)
 | |
|                     for media_type, item_type in zip(media_types, item_types)
 | |
|                 }
 | |
| 
 | |
|         if is_generic_list(payload_type):
 | |
|             media_type = "application/jsonl"
 | |
|             item_type = unwrap_generic_list(payload_type)
 | |
|         else:
 | |
|             media_type = "application/json"
 | |
|             item_type = payload_type
 | |
| 
 | |
|         return {media_type: self.build_media_type(item_type, examples)}
 | |
| 
 | |
|     def build_media_type(
 | |
|         self, item_type: type, examples: Optional[List[Any]] = None
 | |
|     ) -> MediaType:
 | |
|         schema = self.schema_builder.classdef_to_ref(item_type)
 | |
|         if self.schema_transformer:
 | |
|             schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = (
 | |
|                 self.schema_transformer
 | |
|             )
 | |
|             schema = schema_transformer(schema)
 | |
| 
 | |
|         if not examples:
 | |
|             return MediaType(schema=schema)
 | |
| 
 | |
|         if len(examples) == 1:
 | |
|             return MediaType(schema=schema, example=self._build_example(examples[0]))
 | |
| 
 | |
|         return MediaType(
 | |
|             schema=schema,
 | |
|             examples=self._build_examples(examples),
 | |
|         )
 | |
| 
 | |
|     def _build_examples(
 | |
|         self, examples: List[Any]
 | |
|     ) -> Dict[str, Union[Example, ExampleRef]]:
 | |
|         "Creates a set of several examples for a media type."
 | |
| 
 | |
|         if self.sample_transformer:
 | |
|             sample_transformer: Callable[[JsonType], JsonType] = self.sample_transformer  # type: ignore
 | |
|         else:
 | |
|             sample_transformer = lambda sample: sample
 | |
| 
 | |
|         results: Dict[str, Union[Example, ExampleRef]] = {}
 | |
|         for example in examples:
 | |
|             value = sample_transformer(object_to_json(example))
 | |
| 
 | |
|             hash_string = (
 | |
|                 hashlib.sha256(json_dump_string(value).encode("utf-8"))
 | |
|                 .digest()
 | |
|                 .hex()[:16]
 | |
|             )
 | |
|             name = f"ex-{hash_string}"
 | |
| 
 | |
|             results[name] = Example(value=value)
 | |
| 
 | |
|         return results
 | |
| 
 | |
|     def _build_example(self, example: Any) -> Any:
 | |
|         "Creates a single example for a media type."
 | |
| 
 | |
|         if self.sample_transformer:
 | |
|             sample_transformer: Callable[[JsonType], JsonType] = self.sample_transformer  # type: ignore
 | |
|         else:
 | |
|             sample_transformer = lambda sample: sample
 | |
| 
 | |
|         return sample_transformer(object_to_json(example))
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class ResponseOptions:
 | |
|     """
 | |
|     Configuration options for building a response for an operation.
 | |
| 
 | |
|     :param type_descriptions: Maps each response type to a textual description (if available).
 | |
|     :param examples: A list of response examples.
 | |
|     :param status_catalog: Maps each response type to an HTTP status code.
 | |
|     :param default_status_code: HTTP status code assigned to responses that have no mapping.
 | |
|     """
 | |
| 
 | |
|     type_descriptions: Dict[type, str]
 | |
|     examples: Optional[List[Any]]
 | |
|     status_catalog: Dict[type, HTTPStatusCode]
 | |
|     default_status_code: HTTPStatusCode
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class StatusResponse:
 | |
|     status_code: str
 | |
|     types: List[type] = dataclasses.field(default_factory=list)
 | |
|     examples: List[Any] = dataclasses.field(default_factory=list)
 | |
| 
 | |
| 
 | |
| def create_docstring_for_request(
 | |
|     request_name: str, fields: List[Tuple[str, type, Any]], doc_params: Dict[str, str]
 | |
| ) -> str:
 | |
|     """Creates a ReST-style docstring for a dynamically generated request dataclass."""
 | |
|     lines = ["\n"]  # Short description
 | |
| 
 | |
|     # Add parameter documentation in ReST format
 | |
|     for name, type_ in fields:
 | |
|         desc = doc_params.get(name, "")
 | |
|         lines.append(f":param {name}: {desc}")
 | |
| 
 | |
|     return "\n".join(lines)
 | |
| 
 | |
| 
 | |
| class ResponseBuilder:
 | |
|     content_builder: ContentBuilder
 | |
| 
 | |
|     def __init__(self, content_builder: ContentBuilder) -> None:
 | |
|         self.content_builder = content_builder
 | |
| 
 | |
|     def _get_status_responses(
 | |
|         self, options: ResponseOptions
 | |
|     ) -> Dict[str, StatusResponse]:
 | |
|         status_responses: Dict[str, StatusResponse] = {}
 | |
| 
 | |
|         for response_type in options.type_descriptions.keys():
 | |
|             status_code = http_status_to_string(
 | |
|                 options.status_catalog.get(response_type, options.default_status_code)
 | |
|             )
 | |
| 
 | |
|             # look up response for status code
 | |
|             if status_code not in status_responses:
 | |
|                 status_responses[status_code] = StatusResponse(status_code)
 | |
|             status_response = status_responses[status_code]
 | |
| 
 | |
|             # append response types that are assigned the given status code
 | |
|             status_response.types.append(response_type)
 | |
| 
 | |
|             # append examples that have the matching response type
 | |
|             if options.examples:
 | |
|                 status_response.examples.extend(
 | |
|                     example
 | |
|                     for example in options.examples
 | |
|                     if isinstance(example, response_type)
 | |
|                 )
 | |
| 
 | |
|         return dict(sorted(status_responses.items()))
 | |
| 
 | |
|     def build_response(
 | |
|         self, options: ResponseOptions
 | |
|     ) -> Dict[str, Union[Response, ResponseRef]]:
 | |
|         """
 | |
|         Groups responses that have the same status code.
 | |
|         """
 | |
| 
 | |
|         responses: Dict[str, Union[Response, ResponseRef]] = {}
 | |
|         status_responses = self._get_status_responses(options)
 | |
|         for status_code, status_response in status_responses.items():
 | |
|             response_types = tuple(status_response.types)
 | |
|             if len(response_types) > 1:
 | |
|                 composite_response_type: type = Union[response_types]  # type: ignore
 | |
|             else:
 | |
|                 (response_type,) = response_types
 | |
|                 composite_response_type = response_type
 | |
| 
 | |
|             description = " **OR** ".join(
 | |
|                 filter(
 | |
|                     None,
 | |
|                     (
 | |
|                         options.type_descriptions[response_type]
 | |
|                         for response_type in response_types
 | |
|                     ),
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|             responses[status_code] = self._build_response(
 | |
|                 response_type=composite_response_type,
 | |
|                 description=description,
 | |
|                 examples=status_response.examples or None,
 | |
|             )
 | |
| 
 | |
|         return responses
 | |
| 
 | |
|     def _build_response(
 | |
|         self,
 | |
|         response_type: type,
 | |
|         description: str,
 | |
|         examples: Optional[List[Any]] = None,
 | |
|     ) -> Response:
 | |
|         "Creates a response subtree."
 | |
| 
 | |
|         if response_type is not None:
 | |
|             return Response(
 | |
|                 description=description,
 | |
|                 content=self.content_builder.build_content(response_type, examples),
 | |
|             )
 | |
|         else:
 | |
|             return Response(description=description)
 | |
| 
 | |
| 
 | |
| def schema_error_wrapper(schema: SchemaOrRef) -> Schema:
 | |
|     "Wraps an error output schema into a top-level error schema."
 | |
| 
 | |
|     return {
 | |
|         "type": "object",
 | |
|         "properties": {
 | |
|             "error": schema,  # type: ignore
 | |
|         },
 | |
|         "additionalProperties": False,
 | |
|         "required": [
 | |
|             "error",
 | |
|         ],
 | |
|     }
 | |
| 
 | |
| 
 | |
| def sample_error_wrapper(error: JsonType) -> JsonType:
 | |
|     "Wraps an error output sample into a top-level error sample."
 | |
| 
 | |
|     return {"error": error}
 | |
| 
 | |
| 
 | |
| class Generator:
 | |
|     endpoint: type
 | |
|     options: Options
 | |
|     schema_builder: SchemaBuilder
 | |
|     responses: Dict[str, Response]
 | |
| 
 | |
|     def __init__(self, endpoint: type, options: Options) -> None:
 | |
|         self.endpoint = endpoint
 | |
|         self.options = options
 | |
|         schema_generator = JsonSchemaGenerator(
 | |
|             SchemaOptions(
 | |
|                 definitions_path="#/components/schemas/",
 | |
|                 use_examples=self.options.use_examples,
 | |
|                 property_description_fun=options.property_description_fun,
 | |
|             )
 | |
|         )
 | |
|         self.schema_builder = SchemaBuilder(schema_generator)
 | |
|         self.responses = {}
 | |
| 
 | |
|         # Create standard error responses
 | |
|         self._create_standard_error_responses()
 | |
| 
 | |
|     def _create_standard_error_responses(self) -> None:
 | |
|         """
 | |
|         Creates standard error responses that can be reused across operations.
 | |
|         These will be added to the components.responses section of the OpenAPI document.
 | |
|         """
 | |
|         # Get the Error schema
 | |
|         error_schema = self.schema_builder.classdef_to_ref(Error)
 | |
| 
 | |
|         # Create standard error responses
 | |
|         self.responses["BadRequest400"] = Response(
 | |
|             description="The request was invalid or malformed",
 | |
|             content={
 | |
|                 "application/json": MediaType(
 | |
|                     schema=error_schema,
 | |
|                     example={
 | |
|                         "status": 400,
 | |
|                         "title": "Bad Request",
 | |
|                         "detail": "The request was invalid or malformed",
 | |
|                     },
 | |
|                 )
 | |
|             },
 | |
|         )
 | |
| 
 | |
|         self.responses["TooManyRequests429"] = Response(
 | |
|             description="The client has sent too many requests in a given amount of time",
 | |
|             content={
 | |
|                 "application/json": MediaType(
 | |
|                     schema=error_schema,
 | |
|                     example={
 | |
|                         "status": 429,
 | |
|                         "title": "Too Many Requests",
 | |
|                         "detail": "You have exceeded the rate limit. Please try again later.",
 | |
|                     },
 | |
|                 )
 | |
|             },
 | |
|         )
 | |
| 
 | |
|         self.responses["InternalServerError500"] = Response(
 | |
|             description="The server encountered an unexpected error",
 | |
|             content={
 | |
|                 "application/json": MediaType(
 | |
|                     schema=error_schema,
 | |
|                     example={
 | |
|                         "status": 500,
 | |
|                         "title": "Internal Server Error",
 | |
|                         "detail": "An unexpected error occurred. Our team has been notified.",
 | |
|                     },
 | |
|                 )
 | |
|             },
 | |
|         )
 | |
| 
 | |
|         # Add a default error response for any unhandled error cases
 | |
|         self.responses["DefaultError"] = Response(
 | |
|             description="An unexpected error occurred",
 | |
|             content={
 | |
|                 "application/json": MediaType(
 | |
|                     schema=error_schema,
 | |
|                     example={
 | |
|                         "status": 0,
 | |
|                         "title": "Error",
 | |
|                         "detail": "An unexpected error occurred",
 | |
|                     },
 | |
|                 )
 | |
|             },
 | |
|         )
 | |
| 
 | |
|     def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
 | |
|         # Don't include schema definition in the tag description because for one,
 | |
|         # it is not very valuable and for another, it causes string formatting
 | |
|         # discrepancies via the Stainless Studio.
 | |
|         #
 | |
|         # definition = f'<SchemaDefinition schemaRef="#/components/schemas/{ref}" />'
 | |
|         title = typing.cast(str, schema.get("title"))
 | |
|         description = typing.cast(str, schema.get("description"))
 | |
|         return Tag(
 | |
|             name=ref,
 | |
|             description="\n\n".join(s for s in (title, description) if s is not None),
 | |
|         )
 | |
| 
 | |
|     def _build_extra_tag_groups(
 | |
|         self, extra_types: Dict[str, Dict[str, type]]
 | |
|     ) -> Dict[str, List[Tag]]:
 | |
|         """
 | |
|         Creates a dictionary of tag group captions as keys, and tag lists as values.
 | |
| 
 | |
|         :param extra_types: A dictionary of type categories and list of types in that category.
 | |
|         """
 | |
| 
 | |
|         extra_tags: Dict[str, List[Tag]] = {}
 | |
| 
 | |
|         for category_name, category_items in extra_types.items():
 | |
|             tag_list: List[Tag] = []
 | |
| 
 | |
|             for name, extra_type in category_items.items():
 | |
|                 schema = self.schema_builder.classdef_to_schema(extra_type)
 | |
|                 tag_list.append(self._build_type_tag(name, schema))
 | |
| 
 | |
|             if tag_list:
 | |
|                 extra_tags[category_name] = tag_list
 | |
| 
 | |
|         return extra_tags
 | |
| 
 | |
|     def _build_operation(self, op: EndpointOperation) -> Operation:
 | |
|         if op.defining_class.__name__ in [
 | |
|             "SyntheticDataGeneration",
 | |
|             "PostTraining",
 | |
|             "BatchInference",
 | |
|         ]:
 | |
|             op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
 | |
|             print(op.defining_class.__name__)
 | |
| 
 | |
|         # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
 | |
|         # if op.defining_class.__name__ in ["DatasetIO"]:
 | |
|         #     op.defining_class.__name__ = "Datasets"
 | |
| 
 | |
|         doc_string = parse_type(op.func_ref)
 | |
|         doc_params = dict(
 | |
|             (param.name, param.description) for param in doc_string.params.values()
 | |
|         )
 | |
| 
 | |
|         # parameters passed in URL component path
 | |
|         path_parameters = [
 | |
|             Parameter(
 | |
|                 name=param_name,
 | |
|                 in_=ParameterLocation.Path,
 | |
|                 description=doc_params.get(param_name),
 | |
|                 required=True,
 | |
|                 schema=self.schema_builder.classdef_to_ref(param_type),
 | |
|             )
 | |
|             for param_name, param_type in op.path_params
 | |
|         ]
 | |
| 
 | |
|         # parameters passed in URL component query string
 | |
|         query_parameters = []
 | |
|         for param_name, param_type in op.query_params:
 | |
|             if is_type_optional(param_type):
 | |
|                 inner_type: type = unwrap_optional_type(param_type)
 | |
|                 required = False
 | |
|             else:
 | |
|                 inner_type = param_type
 | |
|                 required = True
 | |
| 
 | |
|             query_parameter = Parameter(
 | |
|                 name=param_name,
 | |
|                 in_=ParameterLocation.Query,
 | |
|                 description=doc_params.get(param_name),
 | |
|                 required=required,
 | |
|                 schema=self.schema_builder.classdef_to_ref(inner_type),
 | |
|             )
 | |
|             query_parameters.append(query_parameter)
 | |
| 
 | |
|         # parameters passed anywhere
 | |
|         parameters = path_parameters + query_parameters
 | |
| 
 | |
|         webmethod = getattr(op.func_ref, "__webmethod__", None)
 | |
|         raw_bytes_request_body = False
 | |
|         if webmethod:
 | |
|             raw_bytes_request_body = getattr(webmethod, "raw_bytes_request_body", False)
 | |
| 
 | |
|         # data passed in request body as raw bytes cannot have request parameters
 | |
|         if raw_bytes_request_body and op.request_params:
 | |
|             raise ValueError(
 | |
|                 "Cannot have both raw bytes request body and request parameters"
 | |
|             )
 | |
| 
 | |
|         # data passed in request body as raw bytes
 | |
|         if raw_bytes_request_body:
 | |
|             requestBody = RequestBody(
 | |
|                 content={
 | |
|                     "application/octet-stream": {
 | |
|                         "schema": {
 | |
|                             "type": "string",
 | |
|                             "format": "binary",
 | |
|                         }
 | |
|                     }
 | |
|                 },
 | |
|                 required=True,
 | |
|             )
 | |
|         # data passed in request body as multipart/form-data
 | |
|         elif op.multipart_params:
 | |
|             builder = ContentBuilder(self.schema_builder)
 | |
|             
 | |
|             # Create schema properties for multipart form fields
 | |
|             properties = {}
 | |
|             required_fields = []
 | |
|             
 | |
|             for name, param_type in op.multipart_params:
 | |
|                 if get_origin(param_type) is Annotated:
 | |
|                     base_type = get_args(param_type)[0]
 | |
|                 else:
 | |
|                     base_type = param_type
 | |
|                 if base_type is UploadFile:
 | |
|                     # File upload
 | |
|                     properties[name] = {
 | |
|                         "type": "string",
 | |
|                         "format": "binary"
 | |
|                     }
 | |
|                 else:
 | |
|                     # Form field
 | |
|                     properties[name] = self.schema_builder.classdef_to_ref(base_type)
 | |
|                 
 | |
|                 required_fields.append(name)
 | |
|             
 | |
|             multipart_schema = {
 | |
|                 "type": "object",
 | |
|                 "properties": properties,
 | |
|                 "required": required_fields
 | |
|             }
 | |
|             
 | |
|             requestBody = RequestBody(
 | |
|                 content={
 | |
|                     "multipart/form-data": {
 | |
|                         "schema": multipart_schema
 | |
|                     }
 | |
|                 },
 | |
|                 required=True,
 | |
|             )
 | |
|         # data passed in payload as JSON and mapped to request parameters
 | |
|         elif op.request_params:
 | |
|             builder = ContentBuilder(self.schema_builder)
 | |
|             first = next(iter(op.request_params))
 | |
|             request_name, request_type = first
 | |
| 
 | |
|             op_name = "".join(word.capitalize() for word in op.name.split("_"))
 | |
|             request_name = f"{op_name}Request"
 | |
|             fields = [
 | |
|                 (
 | |
|                     name,
 | |
|                     type_,
 | |
|                 )
 | |
|                 for name, type_ in op.request_params
 | |
|             ]
 | |
|             request_type = make_dataclass(
 | |
|                 request_name,
 | |
|                 fields,
 | |
|                 namespace={
 | |
|                     "__doc__": create_docstring_for_request(
 | |
|                         request_name, fields, doc_params
 | |
|                     )
 | |
|                 },
 | |
|             )
 | |
| 
 | |
|             requestBody = RequestBody(
 | |
|                 content={
 | |
|                     "application/json": builder.build_media_type(
 | |
|                         request_type, op.request_examples
 | |
|                     )
 | |
|                 },
 | |
|                 description=doc_params.get(request_name),
 | |
|                 required=True,
 | |
|             )
 | |
|         else:
 | |
|             requestBody = None
 | |
| 
 | |
|         # success response types
 | |
|         if doc_string.returns is None and is_type_union(op.response_type):
 | |
|             # split union of return types into a list of response types
 | |
|             success_type_docstring: Dict[type, Docstring] = {
 | |
|                 typing.cast(type, item): parse_type(item)
 | |
|                 for item in unwrap_union_types(op.response_type)
 | |
|             }
 | |
|             success_type_descriptions = {
 | |
|                 item: doc_string.short_description
 | |
|                 for item, doc_string in success_type_docstring.items()
 | |
|             }
 | |
|         else:
 | |
|             # use return type as a single response type
 | |
|             success_type_descriptions = {
 | |
|                 op.response_type: (
 | |
|                     doc_string.returns.description if doc_string.returns else "OK"
 | |
|                 )
 | |
|             }
 | |
| 
 | |
|         response_examples = op.response_examples or []
 | |
|         success_examples = [
 | |
|             example
 | |
|             for example in response_examples
 | |
|             if not isinstance(example, Exception)
 | |
|         ]
 | |
| 
 | |
|         content_builder = ContentBuilder(self.schema_builder)
 | |
|         response_builder = ResponseBuilder(content_builder)
 | |
|         response_options = ResponseOptions(
 | |
|             success_type_descriptions,
 | |
|             success_examples if self.options.use_examples else None,
 | |
|             self.options.success_responses,
 | |
|             "200",
 | |
|         )
 | |
|         responses = response_builder.build_response(response_options)
 | |
| 
 | |
|         # failure response types
 | |
|         if doc_string.raises:
 | |
|             exception_types: Dict[type, str] = {
 | |
|                 item.raise_type: item.description for item in doc_string.raises.values()
 | |
|             }
 | |
|             exception_examples = [
 | |
|                 example
 | |
|                 for example in response_examples
 | |
|                 if isinstance(example, Exception)
 | |
|             ]
 | |
| 
 | |
|             if self.options.error_wrapper:
 | |
|                 schema_transformer = schema_error_wrapper
 | |
|                 sample_transformer = sample_error_wrapper
 | |
|             else:
 | |
|                 schema_transformer = None
 | |
|                 sample_transformer = None
 | |
| 
 | |
|             content_builder = ContentBuilder(
 | |
|                 self.schema_builder,
 | |
|                 schema_transformer=schema_transformer,
 | |
|                 sample_transformer=sample_transformer,
 | |
|             )
 | |
|             response_builder = ResponseBuilder(content_builder)
 | |
|             response_options = ResponseOptions(
 | |
|                 exception_types,
 | |
|                 exception_examples if self.options.use_examples else None,
 | |
|                 self.options.error_responses,
 | |
|                 "500",
 | |
|             )
 | |
|             responses.update(response_builder.build_response(response_options))
 | |
| 
 | |
|         assert len(responses.keys()) > 0, f"No responses found for {op.name}"
 | |
| 
 | |
|         # Add standard error response references
 | |
|         if self.options.include_standard_error_responses:
 | |
|             if "400" not in responses:
 | |
|                 responses["400"] = ResponseRef("BadRequest400")
 | |
|             if "429" not in responses:
 | |
|                 responses["429"] = ResponseRef("TooManyRequests429")
 | |
|             if "500" not in responses:
 | |
|                 responses["500"] = ResponseRef("InternalServerError500")
 | |
|             if "default" not in responses:
 | |
|                 responses["default"] = ResponseRef("DefaultError")
 | |
| 
 | |
|         if op.event_type is not None:
 | |
|             builder = ContentBuilder(self.schema_builder)
 | |
|             callbacks = {
 | |
|                 f"{op.func_name}_callback": {
 | |
|                     "{$request.query.callback}": PathItem(
 | |
|                         post=Operation(
 | |
|                             requestBody=RequestBody(
 | |
|                                 content=builder.build_content(op.event_type)
 | |
|                             ),
 | |
|                             responses={"200": Response(description="OK")},
 | |
|                         )
 | |
|                     )
 | |
|                 }
 | |
|             }
 | |
| 
 | |
|         else:
 | |
|             callbacks = None
 | |
| 
 | |
|         description = "\n".join(
 | |
|             filter(None, [doc_string.short_description, doc_string.long_description])
 | |
|         )
 | |
| 
 | |
|         return Operation(
 | |
|             tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
 | |
|             summary=None,
 | |
|             # summary=doc_string.short_description,
 | |
|             description=description,
 | |
|             parameters=parameters,
 | |
|             requestBody=requestBody,
 | |
|             responses=responses,
 | |
|             callbacks=callbacks,
 | |
|             deprecated=True if "DEPRECATED" in op.func_name else None,
 | |
|             security=[] if op.public else None,
 | |
|         )
 | |
| 
 | |
|     def generate(self) -> Document:
 | |
|         paths: Dict[str, PathItem] = {}
 | |
|         endpoint_classes: Set[type] = set()
 | |
|         for op in get_endpoint_operations(
 | |
|             self.endpoint, use_examples=self.options.use_examples
 | |
|         ):
 | |
|             endpoint_classes.add(op.defining_class)
 | |
| 
 | |
|             operation = self._build_operation(op)
 | |
| 
 | |
|             if op.http_method is HTTPMethod.GET:
 | |
|                 pathItem = PathItem(get=operation)
 | |
|             elif op.http_method is HTTPMethod.PUT:
 | |
|                 pathItem = PathItem(put=operation)
 | |
|             elif op.http_method is HTTPMethod.POST:
 | |
|                 pathItem = PathItem(post=operation)
 | |
|             elif op.http_method is HTTPMethod.DELETE:
 | |
|                 pathItem = PathItem(delete=operation)
 | |
|             elif op.http_method is HTTPMethod.PATCH:
 | |
|                 pathItem = PathItem(patch=operation)
 | |
|             else:
 | |
|                 raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
 | |
| 
 | |
|             route = op.get_route()
 | |
|             route = route.replace(":path", "")
 | |
|             print(f"route: {route}")
 | |
|             if route in paths:
 | |
|                 paths[route].update(pathItem)
 | |
|             else:
 | |
|                 paths[route] = pathItem
 | |
| 
 | |
|         operation_tags: List[Tag] = []
 | |
|         for cls in endpoint_classes:
 | |
|             doc_string = parse_type(cls)
 | |
|             if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
 | |
|                 continue
 | |
|             operation_tags.append(
 | |
|                 Tag(
 | |
|                     name=cls.__name__,
 | |
|                     description=doc_string.long_description,
 | |
|                     displayName=doc_string.short_description,
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|         # types that are emitted by events
 | |
|         event_tags: List[Tag] = []
 | |
|         events = get_endpoint_events(self.endpoint)
 | |
|         for ref, event_type in events.items():
 | |
|             event_schema = self.schema_builder.classdef_to_named_schema(ref, event_type)
 | |
|             event_tags.append(self._build_type_tag(ref, event_schema))
 | |
| 
 | |
|         # types that are explicitly declared
 | |
|         extra_tag_groups: Dict[str, List[Tag]] = {}
 | |
|         if self.options.extra_types is not None:
 | |
|             if isinstance(self.options.extra_types, list):
 | |
|                 extra_tag_groups = self._build_extra_tag_groups(
 | |
|                     {"AdditionalTypes": self.options.extra_types}
 | |
|                 )
 | |
|             elif isinstance(self.options.extra_types, dict):
 | |
|                 extra_tag_groups = self._build_extra_tag_groups(
 | |
|                     self.options.extra_types
 | |
|                 )
 | |
|             else:
 | |
|                 raise TypeError(
 | |
|                     f"type mismatch for collection of extra types: {type(self.options.extra_types)}"
 | |
|                 )
 | |
| 
 | |
|         # list all operations and types
 | |
|         tags: List[Tag] = []
 | |
|         tags.extend(operation_tags)
 | |
|         tags.extend(event_tags)
 | |
|         for extra_tag_group in extra_tag_groups.values():
 | |
|             tags.extend(extra_tag_group)
 | |
| 
 | |
|         tags = sorted(tags, key=lambda t: t.name)
 | |
| 
 | |
|         tag_groups = []
 | |
|         if operation_tags:
 | |
|             tag_groups.append(
 | |
|                 TagGroup(
 | |
|                     name=self.options.map("Operations"),
 | |
|                     tags=sorted(tag.name for tag in operation_tags),
 | |
|                 )
 | |
|             )
 | |
|         if event_tags:
 | |
|             tag_groups.append(
 | |
|                 TagGroup(
 | |
|                     name=self.options.map("Events"),
 | |
|                     tags=sorted(tag.name for tag in event_tags),
 | |
|                 )
 | |
|             )
 | |
|         for caption, extra_tag_group in extra_tag_groups.items():
 | |
|             tag_groups.append(
 | |
|                 TagGroup(
 | |
|                     name=caption,
 | |
|                     tags=sorted(tag.name for tag in extra_tag_group),
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|         if self.options.default_security_scheme:
 | |
|             securitySchemes = {"Default": self.options.default_security_scheme}
 | |
|         else:
 | |
|             securitySchemes = None
 | |
| 
 | |
|         return Document(
 | |
|             openapi=".".join(str(item) for item in self.options.version),
 | |
|             info=self.options.info,
 | |
|             jsonSchemaDialect=(
 | |
|                 "https://json-schema.org/draft/2020-12/schema"
 | |
|                 if self.options.version >= (3, 1, 0)
 | |
|                 else None
 | |
|             ),
 | |
|             servers=[self.options.server],
 | |
|             paths=paths,
 | |
|             components=Components(
 | |
|                 schemas=self.schema_builder.schemas,
 | |
|                 responses=self.responses,
 | |
|                 securitySchemes=securitySchemes,
 | |
|             ),
 | |
|             security=[{"Default": []}],
 | |
|             tags=tags,
 | |
|             tagGroups=tag_groups,
 | |
|         )
 |