Fix OpenAPI generation to have text/event-stream for streamable methods

This commit is contained in:
Ashwin Bharambe 2024-11-14 12:51:38 -08:00
parent acbecbf8b3
commit bba6edd06b
4 changed files with 703 additions and 705 deletions

View file

@ -34,20 +34,6 @@ schema_utils.json_schema_type = json_schema_type
from llama_stack.distribution.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
STREAMING_ENDPOINTS = [
"/agents/turn/create",
"/inference/chat_completion",
]
def patch_sse_stream_responses(spec: Specification):
for path, path_item in spec.document.paths.items():
if path in STREAMING_ENDPOINTS:
content = path_item.post.responses["200"].content.pop("application/json")
path_item.post.responses["200"].content["text/event-stream"] = content
def main(output_dir: str):
output_dir = Path(output_dir)
if not output_dir.exists():
@ -74,8 +60,6 @@ def main(output_dir: str):
),
)
patch_sse_stream_responses(spec)
with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import collections
import hashlib
import ipaddress
import typing
@ -176,9 +177,20 @@ class ContentBuilder:
) -> Dict[str, MediaType]:
"Creates the content subtree for a request or response."
def has_iterator_type(t):
if typing.get_origin(t) is typing.Union:
return any(has_iterator_type(a) for a in typing.get_args(t))
else:
# TODO: needs a proper fix where we let all types correctly flow upwards
# and then test against AsyncIterator
return "StreamChunk" in str(t)
if is_generic_list(payload_type):
media_type = "application/jsonl"
item_type = unwrap_generic_list(payload_type)
elif has_iterator_type(payload_type):
item_type = payload_type
media_type = "text/event-stream"
else:
media_type = "application/json"
item_type = payload_type
@ -671,6 +683,8 @@ class Generator:
for extra_tag_group in extra_tag_groups.values():
tags.extend(extra_tag_group)
tags = sorted(tags, key=lambda t: t.name)
tag_groups = []
if operation_tags:
tag_groups.append(