mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
openapi gen return type fix for streaming/non-streaming (#910)
# What does this PR do?
We need to change
```yaml
/v1/inference/chat-completion:
post:
responses:
'200':
description: >-
If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
content:
text/event-stream:
schema:
oneOf:
- $ref: '#/components/schemas/ChatCompletionResponse'
- $ref: '#/components/schemas/ChatCompletionResponseStreamChunk'
```
into
```yaml
/v1/inference/chat-completion:
post:
responses:
'200':
description: >-
If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
content:
text/event-stream:
schema:
$ref: '#/components/schemas/ChatCompletionResponseStreamChunk'
application/json:
schema:
$ref: '#/components/schemas/ChatCompletionResponse'
```
## Test Plan
**Python**
- tested in SDK sync:
https://github.com/meta-llama/llama-stack-client-python/pull/108
**Node**
- tested w/
https://gist.github.com/yanxi0830/b782f4b91e21dcccdfef8898ce55157e (SDK
udpate follow up)
## Sources
Please link relevant resources if necessary.
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
2f11c7c203
commit
15dcc4ea5e
3 changed files with 433 additions and 419 deletions
|
|
@ -177,20 +177,37 @@ class ContentBuilder:
|
|||
) -> Dict[str, MediaType]:
|
||||
"Creates the content subtree for a request or response."
|
||||
|
||||
def has_iterator_type(t):
|
||||
if typing.get_origin(t) is typing.Union:
|
||||
return any(has_iterator_type(a) for a in typing.get_args(t))
|
||||
def is_iterator_type(t):
|
||||
return "StreamChunk" in str(t)
|
||||
|
||||
def get_media_type(t):
|
||||
if is_generic_list(t):
|
||||
return "application/jsonl"
|
||||
elif is_iterator_type(t):
|
||||
return "text/event-stream"
|
||||
else:
|
||||
# TODO: needs a proper fix where we let all types correctly flow upwards
|
||||
# and then test against AsyncIterator
|
||||
return "StreamChunk" in str(t)
|
||||
return "application/json"
|
||||
|
||||
if typing.get_origin(payload_type) is typing.Union:
|
||||
media_types = []
|
||||
item_types = []
|
||||
for x in typing.get_args(payload_type):
|
||||
media_types.append(get_media_type(x))
|
||||
item_types.append(x)
|
||||
|
||||
if len(set(media_types)) == 1:
|
||||
# all types have the same media type
|
||||
return {media_types[0]: self.build_media_type(payload_type, examples)}
|
||||
else:
|
||||
# different types have different media types
|
||||
return {
|
||||
media_type: self.build_media_type(item_type, examples)
|
||||
for media_type, item_type in zip(media_types, item_types)
|
||||
}
|
||||
|
||||
if is_generic_list(payload_type):
|
||||
media_type = "application/jsonl"
|
||||
item_type = unwrap_generic_list(payload_type)
|
||||
elif has_iterator_type(payload_type):
|
||||
item_type = payload_type
|
||||
media_type = "text/event-stream"
|
||||
else:
|
||||
media_type = "application/json"
|
||||
item_type = payload_type
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue