mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
fix gen streaming
This commit is contained in:
parent
7fe2592795
commit
b94d8e9880
1 changed files with 34 additions and 9 deletions
|
@ -177,20 +177,45 @@ class ContentBuilder:
|
|||
) -> Dict[str, MediaType]:
|
||||
"Creates the content subtree for a request or response."
|
||||
|
||||
def has_iterator_type(t):
|
||||
if typing.get_origin(t) is typing.Union:
|
||||
return any(has_iterator_type(a) for a in typing.get_args(t))
|
||||
# def has_iterator_type(t):
|
||||
# if typing.get_origin(t) is typing.Union:
|
||||
# return any(has_iterator_type(a) for a in typing.get_args(t))
|
||||
# else:
|
||||
# # TODO: needs a proper fix where we let all types correctly flow upwards
|
||||
# # and then test against AsyncIterator
|
||||
# return "StreamChunk" in str(t)
|
||||
|
||||
def is_iterator_type(t):
|
||||
return "StreamChunk" in str(t)
|
||||
|
||||
def get_media_type(t):
|
||||
if is_generic_list(t):
|
||||
return "application/jsonl"
|
||||
elif is_iterator_type(t):
|
||||
return "text/event-stream"
|
||||
else:
|
||||
# TODO: needs a proper fix where we let all types correctly flow upwards
|
||||
# and then test against AsyncIterator
|
||||
return "StreamChunk" in str(t)
|
||||
return "application/json"
|
||||
|
||||
if typing.get_origin(payload_type) is typing.Union:
|
||||
media_types = []
|
||||
item_types = []
|
||||
for x in typing.get_args(payload_type):
|
||||
media_types.append(get_media_type(x))
|
||||
item_types.append(x)
|
||||
|
||||
if len(set(media_types)) == 1:
|
||||
# all types have the same media type
|
||||
return {media_types[0]: self.build_media_type(payload_type, examples)}
|
||||
else:
|
||||
# different types have different media types
|
||||
return {
|
||||
media_type: self.build_media_type(item_type, examples)
|
||||
for media_type, item_type in zip(media_types, item_types)
|
||||
}
|
||||
|
||||
if is_generic_list(payload_type):
|
||||
media_type = "application/jsonl"
|
||||
item_type = unwrap_generic_list(payload_type)
|
||||
elif has_iterator_type(payload_type):
|
||||
item_type = payload_type
|
||||
media_type = "text/event-stream"
|
||||
else:
|
||||
media_type = "application/json"
|
||||
item_type = payload_type
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue