chore: re-add text/event-stream media type

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-12 11:30:37 +01:00
parent 66056ddb87
commit e3d831f504
No known key found for this signature in database
4 changed files with 120 additions and 26 deletions

View file

@ -1513,6 +1513,9 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIResponseObject' $ref: '#/components/schemas/OpenAIResponseObject'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAIResponseObjectStream'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
description: Bad Request description: Bad Request
@ -1866,6 +1869,9 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIChatCompletion' $ref: '#/components/schemas/OpenAIChatCompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAIChatCompletionChunk'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
description: Bad Request description: Bad Request

View file

@ -121,6 +121,9 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIResponseObject' $ref: '#/components/schemas/OpenAIResponseObject'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAIResponseObjectStream'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
description: Bad Request description: Bad Request
@ -454,6 +457,9 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIChatCompletion' $ref: '#/components/schemas/OpenAIChatCompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAIChatCompletionChunk'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
description: Bad Request description: Bad Request

View file

@ -1513,6 +1513,9 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIResponseObject' $ref: '#/components/schemas/OpenAIResponseObject'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAIResponseObjectStream'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
description: Bad Request description: Bad Request
@ -1864,6 +1867,9 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/OpenAIChatCompletion' $ref: '#/components/schemas/OpenAIChatCompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAIChatCompletionChunk'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
description: Bad Request description: Bad Request

View file

@ -12,6 +12,8 @@ FastAPI-based OpenAPI generator for Llama Stack.
import importlib import importlib
import inspect import inspect
import pkgutil import pkgutil
import types
import typing
from pathlib import Path from pathlib import Path
from typing import Annotated, Any, get_args, get_origin from typing import Annotated, Any, get_args, get_origin
@ -237,7 +239,9 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
name = route.name name = route.name
fastapi_path = path.replace("{", "{").replace("}", "}") fastapi_path = path.replace("{", "{").replace("}", "}")
request_model, response_model, query_parameters, file_form_params = _find_models_for_endpoint(webmethod, api, name) request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
_find_models_for_endpoint(webmethod, api, name)
)
operation_description = _extract_operation_description_from_docstring(api, name) operation_description = _extract_operation_description_from_docstring(api, name)
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name) response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods) is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
@ -336,6 +340,32 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
no_params_endpoint.__doc__ = operation_description no_params_endpoint.__doc__ = operation_description
endpoint_func = no_params_endpoint endpoint_func = no_params_endpoint
# Build response content with both application/json and text/event-stream if streaming
response_content = {}
if response_model:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
if streaming_response_model:
# Get the schema name for the streaming model
# It might be a registered schema or a Pydantic model
streaming_schema_name = None
# Check if it's a registered schema first (before checking __name__)
# because registered schemas might be Annotated types
from llama_stack.schema_utils import _registered_schemas
if streaming_response_model in _registered_schemas:
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
elif hasattr(streaming_response_model, "__name__"):
streaming_schema_name = streaming_response_model.__name__
if streaming_schema_name:
response_content["text/event-stream"] = {
"schema": {"$ref": f"#/components/schemas/{streaming_schema_name}"}
}
# If no content types, use empty schema
if not response_content:
response_content["application/json"] = {"schema": {}}
# Add the endpoint to the FastAPI app # Add the endpoint to the FastAPI app
is_deprecated = webmethod.deprecated or False is_deprecated = webmethod.deprecated or False
route_kwargs = { route_kwargs = {
@ -345,11 +375,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"responses": { "responses": {
200: { 200: {
"description": response_description, "description": response_description,
"content": { "content": response_content,
"application/json": {
"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {}
}
},
}, },
400: {"$ref": "#/components/responses/BadRequest400"}, 400: {"$ref": "#/components/responses/BadRequest400"},
429: {"$ref": "#/components/responses/TooManyRequests429"}, 429: {"$ref": "#/components/responses/TooManyRequests429"},
@ -422,9 +448,59 @@ def _is_extra_body_field(metadata_item: Any) -> bool:
return isinstance(metadata_item, ExtraBodyField) return isinstance(metadata_item, ExtraBodyField)
def _is_async_iterator_type(type_obj: Any) -> bool:
    """Return True when ``type_obj`` denotes an async stream type.

    Recognised forms are ``AsyncIterator``, ``AsyncIterable`` and
    ``AsyncGenerator`` from ``collections.abc`` (or their ``typing`` aliases),
    either bare or parameterised, e.g. ``AsyncIterator[Chunk]``.

    Args:
        type_obj: Any annotation object. Values that are not async stream
            types (including ``None``) yield ``False``.

    Returns:
        True if the annotation itself, its ``get_origin()`` result, or its
        ``__origin__`` attribute is one of the async stream ABCs.
    """
    from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator

    # AsyncGenerator[Y, S] is also an async stream and its first type argument
    # is the yielded type, exactly like AsyncIterator[Y]; without it a method
    # annotated with AsyncGenerator would not be detected as streaming.
    stream_types = (AsyncIterator, AsyncIterable, AsyncGenerator)

    origin = get_origin(type_obj)
    if origin is not None:
        return origin in stream_types

    # get_origin() returned None: either the bare ABC class was passed, or an
    # alias object that get_origin() does not understand but that still
    # exposes ``__origin__``.
    return type_obj in stream_types or getattr(type_obj, "__origin__", None) in stream_types
def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, type | None]:
"""
Extract non-streaming and streaming response models from a union type.
Returns:
tuple: (non_streaming_model, streaming_model)
"""
non_streaming_model = None
streaming_model = None
args = get_args(union_type)
for arg in args:
# Check if it's an AsyncIterator
if _is_async_iterator_type(arg):
# Extract the type argument from AsyncIterator[T]
iterator_args = get_args(arg)
if iterator_args:
inner_type = iterator_args[0]
# Check if the inner type is a registered schema (union type)
# or a Pydantic model
if hasattr(inner_type, "model_json_schema"):
streaming_model = inner_type
else:
# Might be a registered schema - check if it's registered
from llama_stack.schema_utils import _registered_schemas
if inner_type in _registered_schemas:
# We'll need to look this up later, but for now store the type
streaming_model = inner_type
elif hasattr(arg, "model_json_schema"):
# Non-streaming Pydantic model
if non_streaming_model is None:
non_streaming_model = arg
return non_streaming_model, streaming_model
def _find_models_for_endpoint( def _find_models_for_endpoint(
webmethod, api: Api, method_name: str webmethod, api: Api, method_name: str
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]: ) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
""" """
Find appropriate request and response models for an endpoint by analyzing the actual function signature. Find appropriate request and response models for an endpoint by analyzing the actual function signature.
This uses the protocol function to determine the correct models dynamically. This uses the protocol function to determine the correct models dynamically.
@ -435,15 +511,16 @@ def _find_models_for_endpoint(
method_name: The method name (function name) method_name: The method name (function name)
Returns: Returns:
tuple: (request_model, response_model, query_parameters, file_form_params) tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
where query_parameters is a list of (name, type, default_value) tuples where query_parameters is a list of (name, type, default_value) tuples
and file_form_params is a list of inspect.Parameter objects for File()/Form() params and file_form_params is a list of inspect.Parameter objects for File()/Form() params
and streaming_response_model is the model for streaming responses (AsyncIterator content)
""" """
try: try:
# Get the function from the protocol # Get the function from the protocol
func = _get_protocol_method(api, method_name) func = _get_protocol_method(api, method_name)
if not func: if not func:
return None, None, [], [] return None, None, [], [], None
# Analyze the function signature # Analyze the function signature
sig = inspect.signature(func) sig = inspect.signature(func)
@ -542,38 +619,37 @@ def _find_models_for_endpoint(
query_parameters = [] # Clear query_parameters so we use the single model query_parameters = [] # Clear query_parameters so we use the single model
# Find response model from return annotation # Find response model from return annotation
# Also detect streaming response models (AsyncIterator)
response_model = None response_model = None
streaming_response_model = None
return_annotation = sig.return_annotation return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty: if return_annotation != inspect.Signature.empty:
origin = get_origin(return_annotation)
if hasattr(return_annotation, "model_json_schema"): if hasattr(return_annotation, "model_json_schema"):
response_model = return_annotation response_model = return_annotation
elif get_origin(return_annotation) is Annotated: elif origin is Annotated:
# Handle Annotated return types # Handle Annotated return types
args = get_args(return_annotation) args = get_args(return_annotation)
if args: if args:
# Check if the first argument is a Pydantic model # Check if the first argument is a Pydantic model
if hasattr(args[0], "model_json_schema"): if hasattr(args[0], "model_json_schema"):
response_model = args[0] response_model = args[0]
# Check if the first argument is a union type else:
elif get_origin(args[0]) is type(args[0]): # Union type # Check if the first argument is a union type
union_args = get_args(args[0]) inner_origin = get_origin(args[0])
for arg in union_args: if inner_origin is not None and (
if hasattr(arg, "model_json_schema"): inner_origin is types.UnionType or inner_origin is typing.Union
response_model = arg ):
break response_model, streaming_response_model = _extract_response_models_from_union(args[0])
elif get_origin(return_annotation) is type(return_annotation): # Union type elif origin is not None and (origin is types.UnionType or origin is typing.Union):
# Handle union types - try to find the first Pydantic model # Handle union types - extract both non-streaming and streaming models
args = get_args(return_annotation) response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
for arg in args:
if hasattr(arg, "model_json_schema"):
response_model = arg
break
return request_model, response_model, query_parameters, file_form_params return request_model, response_model, query_parameters, file_form_params, streaming_response_model
except Exception: except Exception:
# If we can't analyze the function signature, return None # If we can't analyze the function signature, return None
return None, None, [], [] return None, None, [], [], None
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None: def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None: