forked from phoenix-oss/llama-stack-mirror
openapi gen return type fix for streaming/non-streaming (#910)
# What does this PR do? We need to change ```yaml /v1/inference/chat-completion: post: responses: '200': description: >- If stream=False, returns a ChatCompletionResponse with the full completion. If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk content: text/event-stream: schema: oneOf: - $ref: '#/components/schemas/ChatCompletionResponse' - $ref: '#/components/schemas/ChatCompletionResponseStreamChunk' ``` into ```yaml /v1/inference/chat-completion: post: responses: '200': description: >- If stream=False, returns a ChatCompletionResponse with the full completion. If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk content: text/event-stream: schema: $ref: '#/components/schemas/ChatCompletionResponseStreamChunk' application/json: schema: $ref: '#/components/schemas/ChatCompletionResponse' ``` ## Test Plan **Python** - tested in SDK sync: https://github.com/meta-llama/llama-stack-client-python/pull/108 **Node** - tested w/ https://gist.github.com/yanxi0830/b782f4b91e21dcccdfef8898ce55157e (SDK udpate follow up) ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
2f11c7c203
commit
15dcc4ea5e
3 changed files with 433 additions and 419 deletions
|
@ -177,20 +177,37 @@ class ContentBuilder:
|
|||
) -> Dict[str, MediaType]:
|
||||
"Creates the content subtree for a request or response."
|
||||
|
||||
def has_iterator_type(t):
|
||||
if typing.get_origin(t) is typing.Union:
|
||||
return any(has_iterator_type(a) for a in typing.get_args(t))
|
||||
def is_iterator_type(t):
|
||||
return "StreamChunk" in str(t)
|
||||
|
||||
def get_media_type(t):
|
||||
if is_generic_list(t):
|
||||
return "application/jsonl"
|
||||
elif is_iterator_type(t):
|
||||
return "text/event-stream"
|
||||
else:
|
||||
# TODO: needs a proper fix where we let all types correctly flow upwards
|
||||
# and then test against AsyncIterator
|
||||
return "StreamChunk" in str(t)
|
||||
return "application/json"
|
||||
|
||||
if typing.get_origin(payload_type) is typing.Union:
|
||||
media_types = []
|
||||
item_types = []
|
||||
for x in typing.get_args(payload_type):
|
||||
media_types.append(get_media_type(x))
|
||||
item_types.append(x)
|
||||
|
||||
if len(set(media_types)) == 1:
|
||||
# all types have the same media type
|
||||
return {media_types[0]: self.build_media_type(payload_type, examples)}
|
||||
else:
|
||||
# different types have different media types
|
||||
return {
|
||||
media_type: self.build_media_type(item_type, examples)
|
||||
for media_type, item_type in zip(media_types, item_types)
|
||||
}
|
||||
|
||||
if is_generic_list(payload_type):
|
||||
media_type = "application/jsonl"
|
||||
item_type = unwrap_generic_list(payload_type)
|
||||
elif has_iterator_type(payload_type):
|
||||
item_type = payload_type
|
||||
media_type = "text/event-stream"
|
||||
else:
|
||||
media_type = "application/json"
|
||||
item_type = payload_type
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue