diff --git a/rfcs/openapi_generator/generate.py b/rfcs/openapi_generator/generate.py index 64f2c8465..4b13904de 100644 --- a/rfcs/openapi_generator/generate.py +++ b/rfcs/openapi_generator/generate.py @@ -36,6 +36,11 @@ schema_utils.json_schema_type = json_schema_type from llama_toolchain.stack import LlamaStack +STREAMING_ENDPOINTS = [ + "/agentic_system/turn/create" +] + + def patched_get_endpoint_functions( endpoint: type, prefixes: List[str] ) -> Iterator[Tuple[str, str, str, Callable]]: @@ -75,6 +80,13 @@ def patched_get_endpoint_functions( operations._get_endpoint_functions = patched_get_endpoint_functions +def patch_sse_stream_responses(spec: Specification): + for path, path_item in spec.document.paths.items(): + if path in STREAMING_ENDPOINTS: + content = path_item.post.responses['200'].content.pop('application/json') + path_item.post.responses['200'].content['text/event-stream'] = content + + def main(output_dir: str): output_dir = Path(output_dir) if not output_dir.exists(): @@ -100,6 +112,9 @@ def main(output_dir: str): ), ), ) + + patch_sse_stream_responses(spec) + with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp: yaml.dump(spec.get_json(), fp, allow_unicode=True) diff --git a/rfcs/openapi_generator/run_openapi_generator.sh b/rfcs/openapi_generator/run_openapi_generator.sh old mode 100644 new mode 100755