From b94d8e9880f48a40d74bb0dadacff1aaff0dfaff Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 30 Jan 2025 16:53:32 -0800 Subject: [PATCH] fix gen streaming --- docs/openapi_generator/pyopenapi/generator.py | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index 202d3732b..5bd0cfe57 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -177,20 +177,45 @@ class ContentBuilder: ) -> Dict[str, MediaType]: "Creates the content subtree for a request or response." - def has_iterator_type(t): - if typing.get_origin(t) is typing.Union: - return any(has_iterator_type(a) for a in typing.get_args(t)) + # def has_iterator_type(t): + # if typing.get_origin(t) is typing.Union: + # return any(has_iterator_type(a) for a in typing.get_args(t)) + # else: + # # TODO: needs a proper fix where we let all types correctly flow upwards + # # and then test against AsyncIterator + # return "StreamChunk" in str(t) + + def is_iterator_type(t): + return "StreamChunk" in str(t) + + def get_media_type(t): + if is_generic_list(t): + return "application/jsonl" + elif is_iterator_type(t): + return "text/event-stream" else: - # TODO: needs a proper fix where we let all types correctly flow upwards - # and then test against AsyncIterator - return "StreamChunk" in str(t) + return "application/json" + + if typing.get_origin(payload_type) is typing.Union: + media_types = [] + item_types = [] + for x in typing.get_args(payload_type): + media_types.append(get_media_type(x)) + item_types.append(x) + + if len(set(media_types)) == 1: + # all types have the same media type + return {media_types[0]: self.build_media_type(payload_type, examples)} + else: + # different types have different media types + return { + media_type: self.build_media_type(item_type, examples) + for media_type, item_type in zip(media_types, item_types) + } if is_generic_list(payload_type): media_type = "application/jsonl" item_type = unwrap_generic_list(payload_type) - elif has_iterator_type(payload_type): - item_type = payload_type - media_type = "text/event-stream" else: media_type = "application/json" item_type = payload_type