diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index cb7ff276d..a5900c18e 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -1513,6 +1513,9 @@ paths: application/json: schema: $ref: '#/components/schemas/OpenAIResponseObject' + text/event-stream: + schema: + $ref: '#/components/schemas/OpenAIResponseObjectStream' '400': $ref: '#/components/responses/BadRequest400' description: Bad Request @@ -1866,6 +1869,9 @@ paths: application/json: schema: $ref: '#/components/schemas/OpenAIChatCompletion' + text/event-stream: + schema: + $ref: '#/components/schemas/OpenAIChatCompletionChunk' '400': $ref: '#/components/responses/BadRequest400' description: Bad Request diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 7370ed9fe..076fa42ae 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -121,6 +121,9 @@ paths: application/json: schema: $ref: '#/components/schemas/OpenAIResponseObject' + text/event-stream: + schema: + $ref: '#/components/schemas/OpenAIResponseObjectStream' '400': $ref: '#/components/responses/BadRequest400' description: Bad Request @@ -454,6 +457,9 @@ paths: application/json: schema: $ref: '#/components/schemas/OpenAIChatCompletion' + text/event-stream: + schema: + $ref: '#/components/schemas/OpenAIChatCompletionChunk' '400': $ref: '#/components/responses/BadRequest400' description: Bad Request diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 5302945fb..3436a10dc 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -1513,6 +1513,9 @@ paths: application/json: schema: $ref: '#/components/schemas/OpenAIResponseObject' + text/event-stream: + schema: + $ref: '#/components/schemas/OpenAIResponseObjectStream' '400': $ref: '#/components/responses/BadRequest400' description: Bad Request @@ -1864,6 +1867,9 @@ paths: application/json: schema: $ref: '#/components/schemas/OpenAIChatCompletion' + text/event-stream: + schema: + $ref: '#/components/schemas/OpenAIChatCompletionChunk' '400': $ref: '#/components/responses/BadRequest400' description: Bad Request diff --git a/scripts/fastapi_generator.py b/scripts/fastapi_generator.py index 396a8cb65..7c3641661 100755 --- a/scripts/fastapi_generator.py +++ b/scripts/fastapi_generator.py @@ -12,6 +12,8 @@ FastAPI-based OpenAPI generator for Llama Stack. import importlib import inspect import pkgutil +import types +import typing from pathlib import Path from typing import Annotated, Any, get_args, get_origin @@ -237,7 +239,9 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api): name = route.name fastapi_path = path.replace("{", "{").replace("}", "}") - request_model, response_model, query_parameters, file_form_params = _find_models_for_endpoint(webmethod, api, name) + request_model, response_model, query_parameters, file_form_params, streaming_response_model = ( + _find_models_for_endpoint(webmethod, api, name) + ) operation_description = _extract_operation_description_from_docstring(api, name) response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name) is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods) @@ -336,6 +340,32 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api): no_params_endpoint.__doc__ = operation_description endpoint_func = no_params_endpoint + # Build response content with both application/json and text/event-stream if streaming + response_content = {} + if response_model: + response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}} + if streaming_response_model: + # Get the schema name for the streaming model + # It might be a registered schema or a Pydantic model + streaming_schema_name = None + # Check if it's a registered schema first (before checking __name__) + # because registered schemas might be Annotated types + from llama_stack.schema_utils import _registered_schemas + + if streaming_response_model in _registered_schemas: + streaming_schema_name = _registered_schemas[streaming_response_model]["name"] + elif hasattr(streaming_response_model, "__name__"): + streaming_schema_name = streaming_response_model.__name__ + + if streaming_schema_name: + response_content["text/event-stream"] = { + "schema": {"$ref": f"#/components/schemas/{streaming_schema_name}"} + } + + # If no content types, use empty schema + if not response_content: + response_content["application/json"] = {"schema": {}} + # Add the endpoint to the FastAPI app is_deprecated = webmethod.deprecated or False route_kwargs = { @@ -345,11 +375,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api): "responses": { 200: { "description": response_description, - "content": { - "application/json": { - "schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {} - } - }, + "content": response_content, }, 400: {"$ref": "#/components/responses/BadRequest400"}, 429: {"$ref": "#/components/responses/TooManyRequests429"}, @@ -422,9 +448,59 @@ def _is_extra_body_field(metadata_item: Any) -> bool: return isinstance(metadata_item, ExtraBodyField) +def _is_async_iterator_type(type_obj: Any) -> bool: + """Check if a type is AsyncIterator or AsyncIterable.""" + from collections.abc import AsyncIterable, AsyncIterator + + origin = get_origin(type_obj) + if origin is None: + # Check if it's the class itself + return type_obj in (AsyncIterator, AsyncIterable) or ( + hasattr(type_obj, "__origin__") and type_obj.__origin__ in (AsyncIterator, AsyncIterable) + ) + return origin in (AsyncIterator, AsyncIterable) + + +def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, type | None]: + """ + Extract non-streaming and streaming response models from a union type. + + Returns: + tuple: (non_streaming_model, streaming_model) + """ + non_streaming_model = None + streaming_model = None + + args = get_args(union_type) + for arg in args: + # Check if it's an AsyncIterator + if _is_async_iterator_type(arg): + # Extract the type argument from AsyncIterator[T] + iterator_args = get_args(arg) + if iterator_args: + inner_type = iterator_args[0] + # Check if the inner type is a registered schema (union type) + # or a Pydantic model + if hasattr(inner_type, "model_json_schema"): + streaming_model = inner_type + else: + # Might be a registered schema - check if it's registered + from llama_stack.schema_utils import _registered_schemas + + if inner_type in _registered_schemas: + # We'll need to look this up later, but for now store the type + streaming_model = inner_type + elif hasattr(arg, "model_json_schema"): + # Non-streaming Pydantic model + if non_streaming_model is None: + non_streaming_model = arg + + return non_streaming_model, streaming_model + + def _find_models_for_endpoint( webmethod, api: Api, method_name: str -) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]: +) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]: """ Find appropriate request and response models for an endpoint by analyzing the actual function signature. This uses the protocol function to determine the correct models dynamically. @@ -435,15 +511,16 @@ def _find_models_for_endpoint( method_name: The method name (function name) Returns: - tuple: (request_model, response_model, query_parameters, file_form_params) + tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model) where query_parameters is a list of (name, type, default_value) tuples and file_form_params is a list of inspect.Parameter objects for File()/Form() params + and streaming_response_model is the model for streaming responses (AsyncIterator content) """ try: # Get the function from the protocol func = _get_protocol_method(api, method_name) if not func: - return None, None, [], [] + return None, None, [], [], None # Analyze the function signature sig = inspect.signature(func) @@ -542,38 +619,37 @@ def _find_models_for_endpoint( query_parameters = [] # Clear query_parameters so we use the single model # Find response model from return annotation + # Also detect streaming response models (AsyncIterator) response_model = None + streaming_response_model = None return_annotation = sig.return_annotation if return_annotation != inspect.Signature.empty: + origin = get_origin(return_annotation) if hasattr(return_annotation, "model_json_schema"): response_model = return_annotation - elif get_origin(return_annotation) is Annotated: + elif origin is Annotated: # Handle Annotated return types args = get_args(return_annotation) if args: # Check if the first argument is a Pydantic model if hasattr(args[0], "model_json_schema"): response_model = args[0] - # Check if the first argument is a union type - elif get_origin(args[0]) is type(args[0]): # Union type - union_args = get_args(args[0]) - for arg in union_args: - if hasattr(arg, "model_json_schema"): - response_model = arg - break - elif get_origin(return_annotation) is type(return_annotation): # Union type - # Handle union types - try to find the first Pydantic model - args = get_args(return_annotation) - for arg in args: - if hasattr(arg, "model_json_schema"): - response_model = arg - break + else: + # Check if the first argument is a union type + inner_origin = get_origin(args[0]) + if inner_origin is not None and ( + inner_origin is types.UnionType or inner_origin is typing.Union + ): + response_model, streaming_response_model = _extract_response_models_from_union(args[0]) + elif origin is not None and (origin is types.UnionType or origin is typing.Union): + # Handle union types - extract both non-streaming and streaming models + response_model, streaming_response_model = _extract_response_models_from_union(return_annotation) - return request_model, response_model, query_parameters, file_form_params + return request_model, response_model, query_parameters, file_form_params, streaming_response_model except Exception: # If we can't analyze the function signature, return None - return None, None, [], [] + return None, None, [], [], None def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None: