mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(openapi): switch to fastapi-based generator (#3944)
Some checks failed
Pre-commit / pre-commit (push) Successful in 3m27s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test llama stack list-deps / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Test llama stack list-deps / show-single-provider (push) Successful in 25s
Test External API and Providers / test-external (venv) (push) Failing after 34s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
Test Llama Stack Build / build (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Successful in 48s
Test llama stack list-deps / list-deps-from-config (push) Successful in 52s
Test llama stack list-deps / list-deps (push) Failing after 52s
Python Package Build Test / build (3.13) (push) Failing after 1m2s
UI Tests / ui-tests (22) (push) Successful in 1m15s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 1m29s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 1m54s
Unit Tests / unit-tests (3.13) (push) Failing after 2m13s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m20s
Some checks failed
Pre-commit / pre-commit (push) Successful in 3m27s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test llama stack list-deps / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Test llama stack list-deps / show-single-provider (push) Successful in 25s
Test External API and Providers / test-external (venv) (push) Failing after 34s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
Test Llama Stack Build / build (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Successful in 48s
Test llama stack list-deps / list-deps-from-config (push) Successful in 52s
Test llama stack list-deps / list-deps (push) Failing after 52s
Python Package Build Test / build (3.13) (push) Failing after 1m2s
UI Tests / ui-tests (22) (push) Successful in 1m15s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 1m29s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 1m54s
Unit Tests / unit-tests (3.13) (push) Failing after 2m13s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m20s
# What does this PR do?
This replaces the legacy "pyopenapi + strong_typing" pipeline with a
FastAPI-backed generator that has an explicit schema registry inside
`llama_stack_api`. The key changes:
1. **New generator architecture.** FastAPI now builds the OpenAPI schema
directly from the real routes, while helper modules
(`schema_collection`, `endpoints`, `schema_transforms`, etc.)
post-process the result. The old pyopenapi stack and its strong_typing
helpers are removed entirely, so we no longer rely on fragile AST
analysis or top-level import side effects.
2. **Schema registry in `llama_stack_api`.** `schema_utils.py` keeps a
`SchemaInfo` record for every `@json_schema_type`, `register_schema`,
and dynamically created request model. The OpenAPI generator and other
tooling query this registry instead of scanning the package tree,
producing deterministic names (e.g., `{MethodName}Request`), capturing
all optional/nullable fields, and making schema discovery testable. A
new unit test covers the registry behavior.
3. **Regenerated specs + CI alignment.** All docs/Stainless specs are
regenerated from the new pipeline, so optional/nullable fields now match
reality (expect the API Conformance workflow to report breaking
changes—this PR establishes the new baseline). The workflow itself is
back to the stock oasdiff invocation so future regressions surface
normally.
*Conformance will be RED on this PR; we choose to accept the
deviations.*
## Test Plan
- `uv run pytest tests/unit/server/test_schema_registry.py`
- `uv run python -m scripts.openapi_generator.main docs/static`
---------
Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
cc88789071
commit
97f535c4f1
64 changed files with 47592 additions and 30218 deletions
|
|
@ -1 +0,0 @@
|
|||
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack.core/server/endpoints.py` using the `generate.py` utility.
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import fire
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
from llama_stack_api import LLAMA_STACK_API_V1 # noqa: E402
|
||||
from llama_stack.core.stack import LlamaStack # noqa: E402
|
||||
|
||||
from .pyopenapi.options import Options # noqa: E402
|
||||
from .pyopenapi.specification import Info, Server # noqa: E402
|
||||
from .pyopenapi.utility import Specification, validate_api # noqa: E402
|
||||
|
||||
|
||||
def str_presenter(dumper, data):
|
||||
if data.startswith(f"/{LLAMA_STACK_API_V1}") or data.startswith(
|
||||
"#/components/schemas/"
|
||||
):
|
||||
style = None
|
||||
else:
|
||||
style = ">" if "\n" in data or len(data) > 40 else None
|
||||
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=style)
|
||||
|
||||
|
||||
def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: bool = False, combined_spec: bool = False):
|
||||
"""Generate OpenAPI spec with optional stability filtering."""
|
||||
|
||||
if combined_spec:
|
||||
# Special case for combined stable + experimental APIs
|
||||
title_suffix = " - Stable & Experimental APIs"
|
||||
filename_prefix = "stainless-"
|
||||
description_suffix = "\n\n**🔗 COMBINED**: This specification includes both stable production-ready APIs and experimental pre-release APIs. Use stable APIs for production deployments and experimental APIs for testing new features."
|
||||
# Use the special "stainless" filter to include stable + experimental APIs
|
||||
stability_filter = "stainless"
|
||||
elif stability_filter:
|
||||
title_suffix = {
|
||||
"stable": " - Stable APIs" if not main_spec else "",
|
||||
"experimental": " - Experimental APIs",
|
||||
"deprecated": " - Deprecated APIs"
|
||||
}.get(stability_filter, f" - {stability_filter.title()} APIs")
|
||||
|
||||
# Use main spec filename for stable when main_spec=True
|
||||
if main_spec and stability_filter == "stable":
|
||||
filename_prefix = ""
|
||||
else:
|
||||
filename_prefix = f"{stability_filter}-"
|
||||
|
||||
description_suffix = {
|
||||
"stable": "\n\n**✅ STABLE**: Production-ready APIs with backward compatibility guarantees.",
|
||||
"experimental": "\n\n**🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before becoming stable.",
|
||||
"deprecated": "\n\n**⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for migration reference only."
|
||||
}.get(stability_filter, "")
|
||||
else:
|
||||
title_suffix = ""
|
||||
filename_prefix = ""
|
||||
description_suffix = ""
|
||||
|
||||
spec = Specification(
|
||||
LlamaStack,
|
||||
Options(
|
||||
server=Server(url="http://any-hosted-llama-stack.com"),
|
||||
info=Info(
|
||||
title=f"Llama Stack Specification{title_suffix}",
|
||||
version=LLAMA_STACK_API_V1,
|
||||
description=f"""This is the specification of the Llama Stack that provides
|
||||
a set of endpoints and their corresponding interfaces that are tailored to
|
||||
best leverage Llama Models.{description_suffix}""",
|
||||
),
|
||||
include_standard_error_responses=True,
|
||||
stability_filter=stability_filter, # Pass the filter to the generator
|
||||
),
|
||||
)
|
||||
|
||||
yaml_filename = f"{filename_prefix}llama-stack-spec.yaml"
|
||||
|
||||
with open(output_dir / yaml_filename, "w", encoding="utf-8") as fp:
|
||||
y = yaml.YAML()
|
||||
y.default_flow_style = False
|
||||
y.block_seq_indent = 2
|
||||
y.map_indent = 2
|
||||
y.sequence_indent = 4
|
||||
y.sequence_dash_offset = 2
|
||||
y.width = 80
|
||||
y.allow_unicode = True
|
||||
y.representer.add_representer(str, str_presenter)
|
||||
|
||||
y.dump(
|
||||
spec.get_json(),
|
||||
fp,
|
||||
)
|
||||
|
||||
def main(output_dir: str):
|
||||
output_dir = Path(output_dir)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"Directory {output_dir} does not exist")
|
||||
|
||||
# Validate API protocols before generating spec
|
||||
return_type_errors = validate_api()
|
||||
if return_type_errors:
|
||||
print("\nAPI Method Return Type Validation Errors:\n")
|
||||
for error in return_type_errors:
|
||||
print(error, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
now = str(datetime.now())
|
||||
print(f"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at {now}")
|
||||
print("")
|
||||
|
||||
# Generate main spec as stable APIs (llama-stack-spec.yaml)
|
||||
print("Generating main specification (stable APIs)...")
|
||||
generate_spec(output_dir, "stable", main_spec=True)
|
||||
|
||||
print("Generating other stability-filtered specifications...")
|
||||
generate_spec(output_dir, "experimental")
|
||||
generate_spec(output_dir, "deprecated")
|
||||
|
||||
print("Generating combined stable + experimental specification...")
|
||||
generate_spec(output_dir, combined_spec=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
|
|
@ -1 +0,0 @@
|
|||
This is forked from https://github.com/hunyadi/pyopenapi
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,459 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import collections.abc
|
||||
import enum
|
||||
import inspect
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from typing import get_origin, get_args
|
||||
|
||||
from fastapi import UploadFile
|
||||
from fastapi.params import File, Form
|
||||
from typing import Annotated
|
||||
|
||||
from llama_stack_api import (
|
||||
ExtraBodyField,
|
||||
LLAMA_STACK_API_V1,
|
||||
LLAMA_STACK_API_V1ALPHA,
|
||||
LLAMA_STACK_API_V1BETA,
|
||||
get_signature,
|
||||
)
|
||||
|
||||
|
||||
def split_prefix(
|
||||
s: str, sep: str, prefix: Union[str, Iterable[str]]
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Recognizes a prefix at the beginning of a string.
|
||||
|
||||
:param s: The string to check.
|
||||
:param sep: A separator between (one of) the prefix(es) and the rest of the string.
|
||||
:param prefix: A string or a set of strings to identify as a prefix.
|
||||
:return: A tuple of the recognized prefix (if any) and the rest of the string excluding the separator (or the entire string).
|
||||
"""
|
||||
|
||||
if isinstance(prefix, str):
|
||||
if s.startswith(prefix + sep):
|
||||
return prefix, s[len(prefix) + len(sep) :]
|
||||
else:
|
||||
return None, s
|
||||
|
||||
for p in prefix:
|
||||
if s.startswith(p + sep):
|
||||
return p, s[len(p) + len(sep) :]
|
||||
|
||||
return None, s
|
||||
|
||||
|
||||
def _get_annotation_type(annotation: Union[type, str], callable: Callable) -> type:
|
||||
"Maps a stringized reference to a type, as if using `from __future__ import annotations`."
|
||||
|
||||
if isinstance(annotation, str):
|
||||
return eval(annotation, callable.__globals__)
|
||||
else:
|
||||
return annotation
|
||||
|
||||
|
||||
class HTTPMethod(enum.Enum):
|
||||
"HTTP method used to invoke an endpoint operation."
|
||||
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
OperationParameter = Tuple[str, type]
|
||||
|
||||
|
||||
class ValidationError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointOperation:
|
||||
"""
|
||||
Type information and metadata associated with an endpoint operation.
|
||||
|
||||
"param defining_class: The most specific class that defines the endpoint operation.
|
||||
:param name: The short name of the endpoint operation.
|
||||
:param func_name: The name of the function to invoke when the operation is triggered.
|
||||
:param func_ref: The callable to invoke when the operation is triggered.
|
||||
:param route: A custom route string assigned to the operation.
|
||||
:param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
|
||||
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
|
||||
:param request_params: The parameter that corresponds to the data transmitted in the request body.
|
||||
:param multipart_params: Parameters that indicate multipart/form-data request body.
|
||||
:param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK.
|
||||
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
|
||||
:param response_type: The Python type of the data that is transmitted in the response body.
|
||||
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
|
||||
:param public: True if the operation can be invoked without prior authentication.
|
||||
:param request_examples: Sample requests that the operation might take.
|
||||
:param response_examples: Sample responses that the operation might produce.
|
||||
"""
|
||||
|
||||
defining_class: type
|
||||
name: str
|
||||
func_name: str
|
||||
func_ref: Callable[..., Any]
|
||||
route: Optional[str]
|
||||
path_params: List[OperationParameter]
|
||||
query_params: List[OperationParameter]
|
||||
request_params: Optional[OperationParameter]
|
||||
multipart_params: List[OperationParameter]
|
||||
extra_body_params: List[tuple[str, type, str | None]]
|
||||
event_type: Optional[type]
|
||||
response_type: type
|
||||
http_method: HTTPMethod
|
||||
public: bool
|
||||
request_examples: Optional[List[Any]] = None
|
||||
response_examples: Optional[List[Any]] = None
|
||||
|
||||
def get_route(self, webmethod) -> str:
|
||||
api_level = webmethod.level
|
||||
|
||||
if self.route is not None:
|
||||
return "/".join(["", api_level, self.route.lstrip("/")])
|
||||
|
||||
route_parts = ["", api_level, self.name]
|
||||
for param_name, _ in self.path_params:
|
||||
route_parts.append("{" + param_name + "}")
|
||||
return "/".join(route_parts)
|
||||
|
||||
|
||||
class _FormatParameterExtractor:
|
||||
"A visitor to exract parameters in a format string."
|
||||
|
||||
keys: List[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.keys = []
|
||||
|
||||
def __getitem__(self, key: str) -> None:
|
||||
self.keys.append(key)
|
||||
return None
|
||||
|
||||
|
||||
def _get_route_parameters(route: str) -> List[str]:
|
||||
extractor = _FormatParameterExtractor()
|
||||
# Replace all occurrences of ":path" with empty string
|
||||
route = route.replace(":path", "")
|
||||
route.format_map(extractor)
|
||||
return extractor.keys
|
||||
|
||||
|
||||
def _get_endpoint_functions(
|
||||
endpoint: type, prefixes: List[str]
|
||||
) -> Iterator[Tuple[str, str, str, Callable]]:
|
||||
if not inspect.isclass(endpoint):
|
||||
raise ValueError(f"object is not a class type: {endpoint}")
|
||||
|
||||
functions = inspect.getmembers(endpoint, inspect.isfunction)
|
||||
for func_name, func_ref in functions:
|
||||
webmethods = []
|
||||
|
||||
# Check for multiple webmethods (stacked decorators)
|
||||
if hasattr(func_ref, "__webmethods__"):
|
||||
webmethods = func_ref.__webmethods__
|
||||
|
||||
if not webmethods:
|
||||
continue
|
||||
|
||||
for webmethod in webmethods:
|
||||
print(f"Processing {colored(func_name, 'white')}...")
|
||||
operation_name = func_name
|
||||
|
||||
if webmethod.method == "GET":
|
||||
prefix = "get"
|
||||
elif webmethod.method == "DELETE":
|
||||
prefix = "delete"
|
||||
elif webmethod.method == "POST":
|
||||
prefix = "post"
|
||||
elif operation_name.startswith("get_") or operation_name.endswith("/get"):
|
||||
prefix = "get"
|
||||
elif (
|
||||
operation_name.startswith("delete_")
|
||||
or operation_name.startswith("remove_")
|
||||
or operation_name.endswith("/delete")
|
||||
or operation_name.endswith("/remove")
|
||||
):
|
||||
prefix = "delete"
|
||||
else:
|
||||
# by default everything else is a POST
|
||||
prefix = "post"
|
||||
|
||||
yield prefix, operation_name, func_name, func_ref
|
||||
|
||||
|
||||
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
|
||||
"Find the class in which a member function is first defined in a class inheritance hierarchy."
|
||||
|
||||
# iterate in reverse member resolution order to find most specific class first
|
||||
for cls in reversed(inspect.getmro(derived_cls)):
|
||||
for name, _ in inspect.getmembers(cls, inspect.isfunction):
|
||||
if name == member_fn:
|
||||
return cls
|
||||
|
||||
raise ValidationError(
|
||||
f"cannot find defining class for {member_fn} in {derived_cls}"
|
||||
)
|
||||
|
||||
|
||||
def get_endpoint_operations(
|
||||
endpoint: type, use_examples: bool = True
|
||||
) -> List[EndpointOperation]:
|
||||
"""
|
||||
Extracts a list of member functions in a class eligible for HTTP interface binding.
|
||||
|
||||
These member functions are expected to have a signature like
|
||||
```
|
||||
async def get_object(self, uuid: str, version: int) -> Object:
|
||||
...
|
||||
```
|
||||
where the prefix `get_` translates to an HTTP GET, `object` corresponds to the name of the endpoint operation,
|
||||
`uuid` and `version` are mapped to route path elements in "/object/{uuid}/{version}", and `Object` becomes
|
||||
the response payload type, transmitted as an object serialized to JSON.
|
||||
|
||||
If the member function has a composite class type in the argument list, it becomes the request payload type,
|
||||
and the caller is expected to provide the data as serialized JSON in an HTTP POST request.
|
||||
|
||||
:param endpoint: A class with member functions that can be mapped to an HTTP endpoint.
|
||||
:param use_examples: Whether to return examples associated with member functions.
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
for prefix, operation_name, func_name, func_ref in _get_endpoint_functions(
|
||||
endpoint,
|
||||
[
|
||||
"create",
|
||||
"delete",
|
||||
"do",
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"remove",
|
||||
"set",
|
||||
"update",
|
||||
],
|
||||
):
|
||||
# Get all webmethods for this function
|
||||
webmethods = getattr(func_ref, "__webmethods__", [])
|
||||
|
||||
# Create one EndpointOperation for each webmethod
|
||||
for webmethod in webmethods:
|
||||
route = webmethod.route
|
||||
route_params = _get_route_parameters(route) if route is not None else None
|
||||
public = webmethod.public
|
||||
request_examples = webmethod.request_examples
|
||||
response_examples = webmethod.response_examples
|
||||
|
||||
# inspect function signature for path and query parameters, and request/response payload type
|
||||
signature = get_signature(func_ref)
|
||||
|
||||
path_params = []
|
||||
query_params = []
|
||||
request_params = []
|
||||
multipart_params = []
|
||||
extra_body_params = []
|
||||
|
||||
for param_name, parameter in signature.parameters.items():
|
||||
param_type = _get_annotation_type(parameter.annotation, func_ref)
|
||||
|
||||
# omit "self" for instance methods
|
||||
if param_name == "self" and param_type is inspect.Parameter.empty:
|
||||
continue
|
||||
|
||||
# check if all parameters have explicit type
|
||||
if parameter.annotation is inspect.Parameter.empty:
|
||||
raise ValidationError(
|
||||
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
|
||||
)
|
||||
|
||||
# Check if this is an extra_body parameter
|
||||
is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
|
||||
if is_extra_body:
|
||||
# Store in a separate list for documentation
|
||||
extra_body_params.append((param_name, param_type, extra_body_desc))
|
||||
continue # Skip adding to request_params
|
||||
|
||||
is_multipart = _is_multipart_param(param_type)
|
||||
|
||||
if prefix in ["get", "delete"]:
|
||||
if route_params is not None and param_name in route_params:
|
||||
path_params.append((param_name, param_type))
|
||||
else:
|
||||
query_params.append((param_name, param_type))
|
||||
else:
|
||||
if route_params is not None and param_name in route_params:
|
||||
path_params.append((param_name, param_type))
|
||||
elif is_multipart:
|
||||
multipart_params.append((param_name, param_type))
|
||||
else:
|
||||
request_params.append((param_name, param_type))
|
||||
|
||||
# check if function has explicit return type
|
||||
if signature.return_annotation is inspect.Signature.empty:
|
||||
raise ValidationError(
|
||||
f"function '{func_name}' has no return type annotation"
|
||||
)
|
||||
|
||||
return_type = _get_annotation_type(signature.return_annotation, func_ref)
|
||||
|
||||
# operations that produce events are labeled as Generator[YieldType, SendType, ReturnType]
|
||||
# where YieldType is the event type, SendType is None, and ReturnType is the immediate response type to the request
|
||||
if typing.get_origin(return_type) is collections.abc.Generator:
|
||||
event_type, send_type, response_type = typing.get_args(return_type)
|
||||
if send_type is not type(None):
|
||||
raise ValidationError(
|
||||
f"function '{func_name}' has a return type Generator[Y,S,R] and therefore looks like an event but has an explicit send type"
|
||||
)
|
||||
else:
|
||||
event_type = None
|
||||
|
||||
def process_type(t):
|
||||
if typing.get_origin(t) is collections.abc.AsyncIterator:
|
||||
# NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List
|
||||
# or the item type. I am choosing it to be the latter
|
||||
args = typing.get_args(t)
|
||||
return args[0]
|
||||
elif typing.get_origin(t) is typing.Union:
|
||||
types = [process_type(a) for a in typing.get_args(t)]
|
||||
return typing._UnionGenericAlias(typing.Union, tuple(types))
|
||||
else:
|
||||
return t
|
||||
|
||||
response_type = process_type(return_type)
|
||||
|
||||
if prefix in ["delete", "remove"]:
|
||||
http_method = HTTPMethod.DELETE
|
||||
elif prefix == "post":
|
||||
http_method = HTTPMethod.POST
|
||||
elif prefix == "get":
|
||||
http_method = HTTPMethod.GET
|
||||
elif prefix == "set":
|
||||
http_method = HTTPMethod.PUT
|
||||
elif prefix == "update":
|
||||
http_method = HTTPMethod.PATCH
|
||||
else:
|
||||
raise ValidationError(f"unknown prefix {prefix}")
|
||||
|
||||
# Create an EndpointOperation for this specific webmethod
|
||||
operation = EndpointOperation(
|
||||
defining_class=_get_defining_class(func_name, endpoint),
|
||||
name=operation_name,
|
||||
func_name=func_name,
|
||||
func_ref=func_ref,
|
||||
route=route,
|
||||
path_params=path_params,
|
||||
query_params=query_params,
|
||||
request_params=request_params,
|
||||
multipart_params=multipart_params,
|
||||
extra_body_params=extra_body_params,
|
||||
event_type=event_type,
|
||||
response_type=response_type,
|
||||
http_method=http_method,
|
||||
public=public,
|
||||
request_examples=request_examples if use_examples else None,
|
||||
response_examples=response_examples if use_examples else None,
|
||||
)
|
||||
|
||||
# Store the specific webmethod with this operation
|
||||
operation.webmethod = webmethod
|
||||
result.append(operation)
|
||||
|
||||
if not result:
|
||||
raise ValidationError(f"no eligible endpoint operations in type {endpoint}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_endpoint_events(endpoint: type) -> Dict[str, type]:
|
||||
results = {}
|
||||
|
||||
for decl in typing.get_type_hints(endpoint).values():
|
||||
# check if signature is Callable[...]
|
||||
origin = typing.get_origin(decl)
|
||||
if origin is None or not issubclass(origin, Callable): # type: ignore
|
||||
continue
|
||||
|
||||
# check if signature is Callable[[...], Any]
|
||||
args = typing.get_args(decl)
|
||||
if len(args) != 2:
|
||||
continue
|
||||
params_type, return_type = args
|
||||
if not isinstance(params_type, list):
|
||||
continue
|
||||
|
||||
# check if signature is Callable[[...], None]
|
||||
if not issubclass(return_type, type(None)):
|
||||
continue
|
||||
|
||||
# check if signature is Callable[[EventType], None]
|
||||
if len(params_type) != 1:
|
||||
continue
|
||||
|
||||
param_type = params_type[0]
|
||||
results[param_type.__name__] = param_type
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _is_multipart_param(param_type: type) -> bool:
|
||||
"""
|
||||
Check if a parameter type indicates multipart form data.
|
||||
|
||||
Returns True if the type is:
|
||||
- UploadFile
|
||||
- Annotated[UploadFile, File()]
|
||||
- Annotated[str, Form()]
|
||||
- Annotated[Any, File()]
|
||||
- Annotated[Any, Form()]
|
||||
"""
|
||||
if param_type is UploadFile:
|
||||
return True
|
||||
|
||||
# Check for Annotated types
|
||||
origin = get_origin(param_type)
|
||||
if origin is None:
|
||||
return False
|
||||
|
||||
if origin is Annotated:
|
||||
args = get_args(param_type)
|
||||
if len(args) < 2:
|
||||
return False
|
||||
|
||||
# Check the annotations for File() or Form()
|
||||
for annotation in args[1:]:
|
||||
if isinstance(annotation, (File, Form)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if parameter is marked as coming from extra_body.
|
||||
|
||||
Returns:
|
||||
(is_extra_body, description): Tuple of boolean and optional description
|
||||
"""
|
||||
origin = get_origin(param_type)
|
||||
if origin is Annotated:
|
||||
args = get_args(param_type)
|
||||
for annotation in args[1:]:
|
||||
if isinstance(annotation, ExtraBodyField):
|
||||
return True, annotation.description
|
||||
# Also check by type name for cases where import matters
|
||||
if type(annotation).__name__ == 'ExtraBodyField':
|
||||
return True, getattr(annotation, 'description', None)
|
||||
return False, None
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from .specification import (
|
||||
Info,
|
||||
SecurityScheme,
|
||||
SecuritySchemeAPI,
|
||||
SecuritySchemeHTTP,
|
||||
SecuritySchemeOpenIDConnect,
|
||||
Server,
|
||||
)
|
||||
|
||||
HTTPStatusCode = Union[HTTPStatus, int, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Options:
|
||||
"""
|
||||
:param server: Base URL for the API endpoint.
|
||||
:param info: Meta-information for the endpoint specification.
|
||||
:param version: OpenAPI specification version as a tuple of major, minor, revision.
|
||||
:param default_security_scheme: Security scheme to apply to endpoints, unless overridden on a per-endpoint basis.
|
||||
:param extra_types: Extra types in addition to those found in operation signatures. Use a dictionary to group related types.
|
||||
:param use_examples: Whether to emit examples for operations.
|
||||
:param success_responses: Associates operation response types with HTTP status codes.
|
||||
:param error_responses: Associates error response types with HTTP status codes.
|
||||
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
|
||||
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
|
||||
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
|
||||
:param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, 503) in all operations.
|
||||
"""
|
||||
|
||||
server: Server
|
||||
info: Info
|
||||
version: Tuple[int, int, int] = (3, 1, 0)
|
||||
default_security_scheme: Optional[SecurityScheme] = None
|
||||
extra_types: Union[List[type], Dict[str, List[type]], None] = None
|
||||
use_examples: bool = True
|
||||
success_responses: Dict[type, HTTPStatusCode] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
error_responses: Dict[type, HTTPStatusCode] = dataclasses.field(
|
||||
default_factory=dict
|
||||
)
|
||||
error_wrapper: bool = False
|
||||
property_description_fun: Optional[Callable[[type, str, str], str]] = None
|
||||
captions: Optional[Dict[str, str]] = None
|
||||
include_standard_error_responses: bool = True
|
||||
stability_filter: Optional[str] = None
|
||||
|
||||
default_captions: ClassVar[Dict[str, str]] = {
|
||||
"Operations": "Operations",
|
||||
"Types": "Types",
|
||||
"Events": "Events",
|
||||
"AdditionalTypes": "Additional types",
|
||||
}
|
||||
|
||||
def map(self, id: str) -> str:
|
||||
"Maps a language-neutral placeholder string to language-dependent text."
|
||||
|
||||
if self.captions is not None:
|
||||
caption = self.captions.get(id)
|
||||
if caption is not None:
|
||||
return caption
|
||||
|
||||
caption = self.__class__.default_captions.get(id)
|
||||
if caption is not None:
|
||||
return caption
|
||||
|
||||
raise KeyError(f"no caption found for ID: {id}")
|
||||
|
|
@ -1,269 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack_api import JsonType, Schema, StrictJsonType
|
||||
|
||||
URL = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ref:
|
||||
ref_type: ClassVar[str]
|
||||
id: str
|
||||
|
||||
def to_json(self) -> StrictJsonType:
|
||||
return {"$ref": f"#/components/{self.ref_type}/{self.id}"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchemaRef(Ref):
|
||||
ref_type: ClassVar[str] = "schemas"
|
||||
|
||||
|
||||
SchemaOrRef = Union[Schema, SchemaRef]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseRef(Ref):
|
||||
ref_type: ClassVar[str] = "responses"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterRef(Ref):
|
||||
ref_type: ClassVar[str] = "parameters"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleRef(Ref):
|
||||
ref_type: ClassVar[str] = "examples"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Contact:
|
||||
name: Optional[str] = None
|
||||
url: Optional[URL] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class License:
|
||||
name: str
|
||||
url: Optional[URL] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Info:
|
||||
title: str
|
||||
version: str
|
||||
description: Optional[str] = None
|
||||
termsOfService: Optional[str] = None
|
||||
contact: Optional[Contact] = None
|
||||
license: Optional[License] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaType:
|
||||
schema: Optional[SchemaOrRef] = None
|
||||
example: Optional[Any] = None
|
||||
examples: Optional[Dict[str, Union["Example", ExampleRef]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestBody:
|
||||
content: Dict[str, MediaType | Dict[str, Any]]
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
description: str
|
||||
content: Optional[Dict[str, MediaType]] = None
|
||||
|
||||
|
||||
class ParameterLocation(enum.Enum):
|
||||
Query = "query"
|
||||
Header = "header"
|
||||
Path = "path"
|
||||
Cookie = "cookie"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Parameter:
|
||||
name: str
|
||||
in_: ParameterLocation
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
schema: Optional[SchemaOrRef] = None
|
||||
example: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraBodyParameter:
|
||||
"""Represents a parameter that arrives via extra_body in the request."""
|
||||
name: str
|
||||
schema: SchemaOrRef
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Operation:
|
||||
responses: Dict[str, Union[Response, ResponseRef]]
|
||||
tags: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
operationId: Optional[str] = None
|
||||
parameters: Optional[List[Parameter]] = None
|
||||
requestBody: Optional[RequestBody] = None
|
||||
callbacks: Optional[Dict[str, "Callback"]] = None
|
||||
security: Optional[List["SecurityRequirement"]] = None
|
||||
deprecated: Optional[bool] = None
|
||||
extraBodyParameters: Optional[List[ExtraBodyParameter]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathItem:
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
get: Optional[Operation] = None
|
||||
put: Optional[Operation] = None
|
||||
post: Optional[Operation] = None
|
||||
delete: Optional[Operation] = None
|
||||
options: Optional[Operation] = None
|
||||
head: Optional[Operation] = None
|
||||
patch: Optional[Operation] = None
|
||||
trace: Optional[Operation] = None
|
||||
|
||||
def update(self, other: "PathItem") -> None:
|
||||
"Merges another instance of this class into this object."
|
||||
|
||||
for field in dataclasses.fields(self.__class__):
|
||||
value = getattr(other, field.name)
|
||||
if value is not None:
|
||||
setattr(self, field.name, value)
|
||||
|
||||
|
||||
# maps run-time expressions such as "$request.body#/url" to path items
|
||||
Callback = Dict[str, PathItem]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Example:
|
||||
summary: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
value: Optional[Any] = None
|
||||
externalValue: Optional[URL] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Server:
|
||||
url: URL
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SecuritySchemeType(enum.Enum):
|
||||
ApiKey = "apiKey"
|
||||
HTTP = "http"
|
||||
OAuth2 = "oauth2"
|
||||
OpenIDConnect = "openIdConnect"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityScheme:
|
||||
type: SecuritySchemeType
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SecuritySchemeAPI(SecurityScheme):
|
||||
name: str
|
||||
in_: ParameterLocation
|
||||
|
||||
def __init__(self, description: str, name: str, in_: ParameterLocation) -> None:
|
||||
super().__init__(SecuritySchemeType.ApiKey, description)
|
||||
self.name = name
|
||||
self.in_ = in_
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SecuritySchemeHTTP(SecurityScheme):
|
||||
scheme: str
|
||||
bearerFormat: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self, description: str, scheme: str, bearerFormat: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__(SecuritySchemeType.HTTP, description)
|
||||
self.scheme = scheme
|
||||
self.bearerFormat = bearerFormat
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SecuritySchemeOpenIDConnect(SecurityScheme):
|
||||
openIdConnectUrl: str
|
||||
|
||||
def __init__(self, description: str, openIdConnectUrl: str) -> None:
|
||||
super().__init__(SecuritySchemeType.OpenIDConnect, description)
|
||||
self.openIdConnectUrl = openIdConnectUrl
|
||||
|
||||
|
||||
@dataclass
|
||||
class Components:
|
||||
schemas: Optional[Dict[str, Schema]] = None
|
||||
responses: Optional[Dict[str, Response]] = None
|
||||
parameters: Optional[Dict[str, Parameter]] = None
|
||||
examples: Optional[Dict[str, Example]] = None
|
||||
requestBodies: Optional[Dict[str, RequestBody]] = None
|
||||
securitySchemes: Optional[Dict[str, SecurityScheme]] = None
|
||||
callbacks: Optional[Dict[str, Callback]] = None
|
||||
|
||||
|
||||
SecurityScope = str
|
||||
SecurityRequirement = Dict[str, List[SecurityScope]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tag:
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
displayName: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TagGroup:
|
||||
"""
|
||||
A ReDoc extension to provide information about groups of tags.
|
||||
|
||||
Exposed via the vendor-specific property "x-tagGroups" of the top-level object.
|
||||
"""
|
||||
|
||||
name: str
|
||||
tags: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
"""
|
||||
This class is a Python dataclass adaptation of the OpenAPI Specification.
|
||||
|
||||
For details, see <https://swagger.io/specification/>
|
||||
"""
|
||||
|
||||
openapi: str
|
||||
info: Info
|
||||
servers: List[Server]
|
||||
paths: Dict[str, PathItem]
|
||||
jsonSchemaDialect: Optional[str] = None
|
||||
components: Optional[Components] = None
|
||||
security: Optional[List[SecurityRequirement]] = None
|
||||
tags: Optional[List[Tag]] = None
|
||||
tagGroups: Optional[List[TagGroup]] = None
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>OpenAPI specification</title>
|
||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||
<script type="module" src="https://cdn.jsdelivr.net/npm/@stoplight/elements/web-components.min.js"></script>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@stoplight/elements/styles.min.css">
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
height: 100vh;
|
||||
}
|
||||
|
||||
elements-api {
|
||||
height: 100%;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<elements-api id="openapi-container" router="hash" layout="sidebar" hideExport="true"
|
||||
hideInternal="true"></elements-api>
|
||||
|
||||
<script>
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
const spec = { /* OPENAPI_SPECIFICATION */ };
|
||||
const element = document.getElementById("openapi-container");
|
||||
element.apiDescriptionDocument = spec;
|
||||
|
||||
if (spec.info && spec.info.title) {
|
||||
document.title = spec.info.title;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
|
|
@ -1,287 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import typing
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, TextIO, Union, get_type_hints, get_origin, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
from llama_stack_api import StrictJsonType, is_unwrapped_body_param, object_to_json
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
|
||||
from .generator import Generator
|
||||
from .options import Options
|
||||
from .specification import Document
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class Specification:
|
||||
document: Document
|
||||
|
||||
def __init__(self, endpoint: type, options: Options):
|
||||
generator = Generator(endpoint, options)
|
||||
self.document = generator.generate()
|
||||
|
||||
def get_json(self) -> StrictJsonType:
|
||||
"""
|
||||
Returns the OpenAPI specification as a Python data type (e.g. `dict` for an object, `list` for an array).
|
||||
|
||||
The result can be serialized to a JSON string with `json.dump` or `json.dumps`.
|
||||
"""
|
||||
|
||||
json_doc = typing.cast(StrictJsonType, object_to_json(self.document))
|
||||
|
||||
if isinstance(json_doc, dict):
|
||||
# rename vendor-specific properties
|
||||
tag_groups = json_doc.pop("tagGroups", None)
|
||||
if tag_groups:
|
||||
json_doc["x-tagGroups"] = tag_groups
|
||||
tags = json_doc.get("tags")
|
||||
if tags and isinstance(tags, list):
|
||||
for tag in tags:
|
||||
if not isinstance(tag, dict):
|
||||
continue
|
||||
|
||||
display_name = tag.pop("displayName", None)
|
||||
if display_name:
|
||||
tag["x-displayName"] = display_name
|
||||
|
||||
# Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params
|
||||
paths = json_doc.get("paths", {})
|
||||
for path_item in paths.values():
|
||||
if isinstance(path_item, dict):
|
||||
for method in ["get", "post", "put", "delete", "patch"]:
|
||||
operation = path_item.get(method)
|
||||
if operation and isinstance(operation, dict):
|
||||
extra_body_params = operation.pop("extraBodyParameters", None)
|
||||
if extra_body_params:
|
||||
operation["x-llama-stack-extra-body-params"] = extra_body_params
|
||||
|
||||
return json_doc
|
||||
|
||||
def get_json_string(self, pretty_print: bool = False) -> str:
|
||||
"""
|
||||
Returns the OpenAPI specification as a JSON string.
|
||||
|
||||
:param pretty_print: Whether to use line indents to beautify the output.
|
||||
"""
|
||||
|
||||
json_doc = self.get_json()
|
||||
if pretty_print:
|
||||
return json.dumps(
|
||||
json_doc, check_circular=False, ensure_ascii=False, indent=4
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
json_doc,
|
||||
check_circular=False,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
def write_json(self, f: TextIO, pretty_print: bool = False) -> None:
|
||||
"""
|
||||
Writes the OpenAPI specification to a file as a JSON string.
|
||||
|
||||
:param pretty_print: Whether to use line indents to beautify the output.
|
||||
"""
|
||||
|
||||
json_doc = self.get_json()
|
||||
if pretty_print:
|
||||
json.dump(
|
||||
json_doc,
|
||||
f,
|
||||
check_circular=False,
|
||||
ensure_ascii=False,
|
||||
indent=4,
|
||||
)
|
||||
else:
|
||||
json.dump(
|
||||
json_doc,
|
||||
f,
|
||||
check_circular=False,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
)
|
||||
|
||||
def write_html(self, f: TextIO, pretty_print: bool = False) -> None:
|
||||
"""
|
||||
Creates a stand-alone HTML page for the OpenAPI specification with ReDoc.
|
||||
|
||||
:param pretty_print: Whether to use line indents to beautify the JSON string in the HTML file.
|
||||
"""
|
||||
|
||||
path = THIS_DIR / "template.html"
|
||||
with path.open(encoding="utf-8", errors="strict") as html_template_file:
|
||||
html_template = html_template_file.read()
|
||||
|
||||
html = html_template.replace(
|
||||
"{ /* OPENAPI_SPECIFICATION */ }",
|
||||
self.get_json_string(pretty_print=pretty_print),
|
||||
)
|
||||
|
||||
f.write(html)
|
||||
|
||||
def is_optional_type(type_: Any) -> bool:
|
||||
"""Check if a type is Optional."""
|
||||
origin = get_origin(type_)
|
||||
args = get_args(type_)
|
||||
return origin is Optional or (origin is Union and type(None) in args)
|
||||
|
||||
|
||||
def _validate_api_method_return_type(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
if is_optional_type(return_type):
|
||||
return "returns Optional type where a return value is mandatory"
|
||||
|
||||
|
||||
def _validate_api_method_doesnt_return_list(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
if get_origin(return_type) is list:
|
||||
return "returns a list where a PaginatedResponse or List*Response object is expected"
|
||||
|
||||
|
||||
def _validate_api_delete_method_returns_none(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
|
||||
# Allow OpenAI endpoints to return response objects since they follow OpenAI specification
|
||||
method_name = getattr(method, '__name__', '')
|
||||
if method_name.__contains__('openai_'):
|
||||
return None
|
||||
|
||||
if return_type is not None and return_type is not type(None):
|
||||
return "does not return None where None is mandatory"
|
||||
|
||||
|
||||
def _validate_list_parameters_contain_data(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
if not inspect.isclass(return_type):
|
||||
return
|
||||
|
||||
if not return_type.__name__.startswith('List'):
|
||||
return
|
||||
|
||||
if 'data' not in return_type.model_fields:
|
||||
return "does not have a mandatory data attribute containing the list of objects"
|
||||
|
||||
|
||||
def _validate_has_ellipsis(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
if "..." not in source and not "NotImplementedError" in source:
|
||||
return "does not contain ellipsis (...) in its implementation"
|
||||
|
||||
def _validate_has_return_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
return_type = method.__annotations__.get('return')
|
||||
if return_type is not None and return_type != type(None) and ":returns:" not in source:
|
||||
return "does not have a ':returns:' in its docstring"
|
||||
|
||||
def _validate_has_params_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
sig = inspect.signature(method)
|
||||
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
param_type = param.annotation
|
||||
if is_unwrapped_body_param(param_type):
|
||||
return
|
||||
|
||||
# Only check if the method has more than one parameter
|
||||
if len(sig.parameters) > 1 and ":param" not in source:
|
||||
return "does not have a ':param' in its docstring"
|
||||
|
||||
def _validate_has_no_return_none_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
return_type = method.__annotations__.get('return')
|
||||
if return_type is None and ":returns: None" in source:
|
||||
return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
|
||||
|
||||
def _validate_docstring_lines_end_with_dot(method) -> str | None:
|
||||
docstring = inspect.getdoc(method)
|
||||
if docstring is None:
|
||||
return None
|
||||
|
||||
lines = docstring.split('\n')
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and not any(line.endswith(char) for char in '.:{}[]()",'):
|
||||
return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
|
||||
|
||||
_VALIDATORS = {
|
||||
"GET": [
|
||||
_validate_api_method_return_type,
|
||||
_validate_list_parameters_contain_data,
|
||||
_validate_api_method_doesnt_return_list,
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_docstring_lines_end_with_dot,
|
||||
],
|
||||
"DELETE": [
|
||||
_validate_api_delete_method_returns_none,
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_has_no_return_none_in_docstring
|
||||
],
|
||||
"POST": [
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_has_no_return_none_in_docstring,
|
||||
_validate_docstring_lines_end_with_dot,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _get_methods_by_type(protocol, method_type: str):
|
||||
members = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
return {
|
||||
method_name: method
|
||||
for method_name, method in members
|
||||
if (webmethod := getattr(method, '__webmethod__', None))
|
||||
if webmethod and webmethod.method == method_type
|
||||
}
|
||||
|
||||
|
||||
def validate_api() -> List[str]:
|
||||
"""Validate the API protocols."""
|
||||
errors = []
|
||||
protocols = api_protocol_map()
|
||||
|
||||
for target, validators in _VALIDATORS.items():
|
||||
for protocol_name, protocol in protocols.items():
|
||||
for validator in validators:
|
||||
for method_name, method in _get_methods_by_type(protocol, target).items():
|
||||
err = validator(method)
|
||||
if err:
|
||||
errors.append(f"Method {protocol_name}.{method_name} {err}")
|
||||
|
||||
return errors
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
PYTHONPATH=${PYTHONPATH:-}
|
||||
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
missing_packages=()
|
||||
|
||||
check_package() {
|
||||
if ! pip show "$1" &>/dev/null; then
|
||||
missing_packages+=("$1")
|
||||
fi
|
||||
}
|
||||
|
||||
if [ ${#missing_packages[@]} -ne 0 ]; then
|
||||
echo "Error: The following package(s) are not installed:"
|
||||
printf " - %s\n" "${missing_packages[@]}"
|
||||
echo "Please install them using:"
|
||||
echo "pip install ${missing_packages[*]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stack_dir=$(dirname $(dirname $THIS_DIR))
|
||||
PYTHONPATH=$PYTHONPATH:$stack_dir \
|
||||
python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/static
|
||||
|
||||
cp $stack_dir/docs/static/stainless-llama-stack-spec.yaml $stack_dir/client-sdks/stainless/openapi.yml
|
||||
10580
docs/static/deprecated-llama-stack-spec.yaml
vendored
10580
docs/static/deprecated-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
10305
docs/static/experimental-llama-stack-spec.yaml
vendored
10305
docs/static/experimental-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
14390
docs/static/llama-stack-spec.yaml
vendored
14390
docs/static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
15875
docs/static/stainless-llama-stack-spec.yaml
vendored
15875
docs/static/stainless-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue