chore: re-add text/event-stream media type

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-12 11:30:37 +01:00
parent 66056ddb87
commit e3d831f504
No known key found for this signature in database
4 changed files with 120 additions and 26 deletions

View file

@ -12,6 +12,8 @@ FastAPI-based OpenAPI generator for Llama Stack.
import importlib
import inspect
import pkgutil
import types
import typing
from pathlib import Path
from typing import Annotated, Any, get_args, get_origin
@ -237,7 +239,9 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
name = route.name
fastapi_path = path.replace("{", "{").replace("}", "}")
request_model, response_model, query_parameters, file_form_params = _find_models_for_endpoint(webmethod, api, name)
request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
_find_models_for_endpoint(webmethod, api, name)
)
operation_description = _extract_operation_description_from_docstring(api, name)
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
@ -336,6 +340,32 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
no_params_endpoint.__doc__ = operation_description
endpoint_func = no_params_endpoint
# Build response content with both application/json and text/event-stream if streaming
response_content = {}
if response_model:
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
if streaming_response_model:
# Get the schema name for the streaming model
# It might be a registered schema or a Pydantic model
streaming_schema_name = None
# Check if it's a registered schema first (before checking __name__)
# because registered schemas might be Annotated types
from llama_stack.schema_utils import _registered_schemas
if streaming_response_model in _registered_schemas:
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
elif hasattr(streaming_response_model, "__name__"):
streaming_schema_name = streaming_response_model.__name__
if streaming_schema_name:
response_content["text/event-stream"] = {
"schema": {"$ref": f"#/components/schemas/{streaming_schema_name}"}
}
# If no content types, use empty schema
if not response_content:
response_content["application/json"] = {"schema": {}}
# Add the endpoint to the FastAPI app
is_deprecated = webmethod.deprecated or False
route_kwargs = {
@ -345,11 +375,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
"responses": {
200: {
"description": response_description,
"content": {
"application/json": {
"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {}
}
},
"content": response_content,
},
400: {"$ref": "#/components/responses/BadRequest400"},
429: {"$ref": "#/components/responses/TooManyRequests429"},
@ -422,9 +448,59 @@ def _is_extra_body_field(metadata_item: Any) -> bool:
return isinstance(metadata_item, ExtraBodyField)
def _is_async_iterator_type(type_obj: Any) -> bool:
"""Check if a type is AsyncIterator or AsyncIterable."""
from collections.abc import AsyncIterable, AsyncIterator
origin = get_origin(type_obj)
if origin is None:
# Check if it's the class itself
return type_obj in (AsyncIterator, AsyncIterable) or (
hasattr(type_obj, "__origin__") and type_obj.__origin__ in (AsyncIterator, AsyncIterable)
)
return origin in (AsyncIterator, AsyncIterable)
def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, type | None]:
"""
Extract non-streaming and streaming response models from a union type.
Returns:
tuple: (non_streaming_model, streaming_model)
"""
non_streaming_model = None
streaming_model = None
args = get_args(union_type)
for arg in args:
# Check if it's an AsyncIterator
if _is_async_iterator_type(arg):
# Extract the type argument from AsyncIterator[T]
iterator_args = get_args(arg)
if iterator_args:
inner_type = iterator_args[0]
# Check if the inner type is a registered schema (union type)
# or a Pydantic model
if hasattr(inner_type, "model_json_schema"):
streaming_model = inner_type
else:
# Might be a registered schema - check if it's registered
from llama_stack.schema_utils import _registered_schemas
if inner_type in _registered_schemas:
# We'll need to look this up later, but for now store the type
streaming_model = inner_type
elif hasattr(arg, "model_json_schema"):
# Non-streaming Pydantic model
if non_streaming_model is None:
non_streaming_model = arg
return non_streaming_model, streaming_model
def _find_models_for_endpoint(
webmethod, api: Api, method_name: str
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]:
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
"""
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
This uses the protocol function to determine the correct models dynamically.
@ -435,15 +511,16 @@ def _find_models_for_endpoint(
method_name: The method name (function name)
Returns:
tuple: (request_model, response_model, query_parameters, file_form_params)
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
where query_parameters is a list of (name, type, default_value) tuples
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
and streaming_response_model is the model for streaming responses (AsyncIterator content)
"""
try:
# Get the function from the protocol
func = _get_protocol_method(api, method_name)
if not func:
return None, None, [], []
return None, None, [], [], None
# Analyze the function signature
sig = inspect.signature(func)
@ -542,38 +619,37 @@ def _find_models_for_endpoint(
query_parameters = [] # Clear query_parameters so we use the single model
# Find response model from return annotation
# Also detect streaming response models (AsyncIterator)
response_model = None
streaming_response_model = None
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
origin = get_origin(return_annotation)
if hasattr(return_annotation, "model_json_schema"):
response_model = return_annotation
elif get_origin(return_annotation) is Annotated:
elif origin is Annotated:
# Handle Annotated return types
args = get_args(return_annotation)
if args:
# Check if the first argument is a Pydantic model
if hasattr(args[0], "model_json_schema"):
response_model = args[0]
# Check if the first argument is a union type
elif get_origin(args[0]) is type(args[0]): # Union type
union_args = get_args(args[0])
for arg in union_args:
if hasattr(arg, "model_json_schema"):
response_model = arg
break
elif get_origin(return_annotation) is type(return_annotation): # Union type
# Handle union types - try to find the first Pydantic model
args = get_args(return_annotation)
for arg in args:
if hasattr(arg, "model_json_schema"):
response_model = arg
break
else:
# Check if the first argument is a union type
inner_origin = get_origin(args[0])
if inner_origin is not None and (
inner_origin is types.UnionType or inner_origin is typing.Union
):
response_model, streaming_response_model = _extract_response_models_from_union(args[0])
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
# Handle union types - extract both non-streaming and streaming models
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
return request_model, response_model, query_parameters, file_form_params
return request_model, response_model, query_parameters, file_form_params, streaming_response_model
except Exception:
# If we can't analyze the function signature, return None
return None, None, [], []
return None, None, [], [], None
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None: