Merge remote-tracking branch 'origin/main' into api_updates_1

This commit is contained in:
Ashwin Bharambe 2024-09-03 21:42:25 -07:00
commit 85d56ed3f2
4 changed files with 782 additions and 417 deletions

View file

@ -18,10 +18,10 @@ import yaml
from llama_models import schema_utils
# We do a series of monkey-patching to ensure our definitions only use the minimal
# We do some monkey-patching to ensure our definitions only use the minimal
# (json_schema_type, webmethod) definitions from the llama_models package. For
# generation though, we need the full definitions and implementations from the
# (python-openapi, json-strong-typing) packages.
# (json-strong-typing) package.
from strong_typing.schema import json_schema_type
@ -31,10 +31,20 @@ from .pyopenapi.utility import Specification
schema_utils.json_schema_type = json_schema_type
from llama_toolchain.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
STREAMING_ENDPOINTS = ["/agentic_system/turn/create"]
def patch_sse_stream_responses(spec: Specification):
for path, path_item in spec.document.paths.items():
if path in STREAMING_ENDPOINTS:
content = path_item.post.responses["200"].content.pop("application/json")
path_item.post.responses["200"].content["text/event-stream"] = content
def main(output_dir: str):
output_dir = Path(output_dir)
if not output_dir.exists():
@ -60,6 +70,9 @@ def main(output_dir: str):
),
),
)
patch_sse_stream_responses(spec)
with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True)