From 35093c0b6f2f19a8c9e380bb014244f31600b7e7 Mon Sep 17 00:00:00 2001 From: Dalton Flanagan <6599399+dltn@users.noreply.github.com> Date: Tue, 3 Sep 2024 23:40:31 -0400 Subject: [PATCH] Add patch for SSE event endpoint responses (#50) --- rfcs/openapi_generator/generate.py | 15 +++++++++++++++ rfcs/openapi_generator/run_openapi_generator.sh | 0 2 files changed, 15 insertions(+) mode change 100644 => 100755 rfcs/openapi_generator/run_openapi_generator.sh diff --git a/rfcs/openapi_generator/generate.py b/rfcs/openapi_generator/generate.py index 64f2c8465..4b13904de 100644 --- a/rfcs/openapi_generator/generate.py +++ b/rfcs/openapi_generator/generate.py @@ -36,6 +36,11 @@ schema_utils.json_schema_type = json_schema_type from llama_toolchain.stack import LlamaStack +STREAMING_ENDPOINTS = [ + "/agentic_system/turn/create" +] + + def patched_get_endpoint_functions( endpoint: type, prefixes: List[str] ) -> Iterator[Tuple[str, str, str, Callable]]: @@ -75,6 +80,13 @@ def patched_get_endpoint_functions( operations._get_endpoint_functions = patched_get_endpoint_functions +def patch_sse_stream_responses(spec: Specification): + for path, path_item in spec.document.paths.items(): + if path in STREAMING_ENDPOINTS: + content = path_item.post.responses['200'].content.pop('application/json') + path_item.post.responses['200'].content['text/event-stream'] = content + + def main(output_dir: str): output_dir = Path(output_dir) if not output_dir.exists(): @@ -100,6 +112,9 @@ def main(output_dir: str): ), ), ) + + patch_sse_stream_responses(spec) + with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp: yaml.dump(spec.get_json(), fp, allow_unicode=True) diff --git a/rfcs/openapi_generator/run_openapi_generator.sh b/rfcs/openapi_generator/run_openapi_generator.sh old mode 100644 new mode 100755