fix gen streaming

This commit is contained in:
Xi Yan 2025-01-30 16:53:32 -08:00
parent 7fe2592795
commit b94d8e9880

View file

@ -177,20 +177,45 @@ class ContentBuilder:
) -> Dict[str, MediaType]:
"Creates the content subtree for a request or response."
def has_iterator_type(t):
if typing.get_origin(t) is typing.Union:
return any(has_iterator_type(a) for a in typing.get_args(t))
# def has_iterator_type(t):
# if typing.get_origin(t) is typing.Union:
# return any(has_iterator_type(a) for a in typing.get_args(t))
# else:
# # TODO: needs a proper fix where we let all types correctly flow upwards
# # and then test against AsyncIterator
# return "StreamChunk" in str(t)
def is_iterator_type(t):
return "StreamChunk" in str(t)
def get_media_type(t):
if is_generic_list(t):
return "application/jsonl"
elif is_iterator_type(t):
return "text/event-stream"
else:
# TODO: needs a proper fix where we let all types correctly flow upwards
# and then test against AsyncIterator
return "StreamChunk" in str(t)
return "application/json"
if typing.get_origin(payload_type) is typing.Union:
media_types = []
item_types = []
for x in typing.get_args(payload_type):
media_types.append(get_media_type(x))
item_types.append(x)
if len(set(media_types)) == 1:
# all types have the same media type
return {media_types[0]: self.build_media_type(payload_type, examples)}
else:
# different types have different media types
return {
media_type: self.build_media_type(item_type, examples)
for media_type, item_type in zip(media_types, item_types)
}
if is_generic_list(payload_type):
media_type = "application/jsonl"
item_type = unwrap_generic_list(payload_type)
elif has_iterator_type(payload_type):
item_type = payload_type
media_type = "text/event-stream"
else:
media_type = "application/json"
item_type = payload_type