mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Fix OpenAPI generation to have text/event-stream for streamable methods
This commit is contained in:
parent
acbecbf8b3
commit
bba6edd06b
4 changed files with 703 additions and 705 deletions
|
@ -34,20 +34,6 @@ schema_utils.json_schema_type = json_schema_type
|
|||
from llama_stack.distribution.stack import LlamaStack
|
||||
|
||||
|
||||
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
|
||||
STREAMING_ENDPOINTS = [
|
||||
"/agents/turn/create",
|
||||
"/inference/chat_completion",
|
||||
]
|
||||
|
||||
|
||||
def patch_sse_stream_responses(spec: Specification):
|
||||
for path, path_item in spec.document.paths.items():
|
||||
if path in STREAMING_ENDPOINTS:
|
||||
content = path_item.post.responses["200"].content.pop("application/json")
|
||||
path_item.post.responses["200"].content["text/event-stream"] = content
|
||||
|
||||
|
||||
def main(output_dir: str):
|
||||
output_dir = Path(output_dir)
|
||||
if not output_dir.exists():
|
||||
|
@ -74,8 +60,6 @@ def main(output_dir: str):
|
|||
),
|
||||
)
|
||||
|
||||
patch_sse_stream_responses(spec)
|
||||
|
||||
with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
|
||||
yaml.dump(spec.get_json(), fp, allow_unicode=True)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import collections
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import typing
|
||||
|
@ -176,9 +177,20 @@ class ContentBuilder:
|
|||
) -> Dict[str, MediaType]:
|
||||
"Creates the content subtree for a request or response."
|
||||
|
||||
def has_iterator_type(t):
|
||||
if typing.get_origin(t) is typing.Union:
|
||||
return any(has_iterator_type(a) for a in typing.get_args(t))
|
||||
else:
|
||||
# TODO: needs a proper fix where we let all types correctly flow upwards
|
||||
# and then test against AsyncIterator
|
||||
return "StreamChunk" in str(t)
|
||||
|
||||
if is_generic_list(payload_type):
|
||||
media_type = "application/jsonl"
|
||||
item_type = unwrap_generic_list(payload_type)
|
||||
elif has_iterator_type(payload_type):
|
||||
item_type = payload_type
|
||||
media_type = "text/event-stream"
|
||||
else:
|
||||
media_type = "application/json"
|
||||
item_type = payload_type
|
||||
|
@ -671,6 +683,8 @@ class Generator:
|
|||
for extra_tag_group in extra_tag_groups.values():
|
||||
tags.extend(extra_tag_group)
|
||||
|
||||
tags = sorted(tags, key=lambda t: t.name)
|
||||
|
||||
tag_groups = []
|
||||
if operation_tags:
|
||||
tag_groups.append(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue