#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. """ FastAPI-based OpenAPI generator for Llama Stack. """ import inspect import json from pathlib import Path from typing import Annotated, Any, Literal, get_args, get_origin import yaml from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from openapi_spec_validator import validate_spec from openapi_spec_validator.exceptions import OpenAPISpecValidatorError from llama_stack.apis.datatypes import Api from llama_stack.core.resolver import api_protocol_map # Import the existing route discovery system from llama_stack.core.server.routes import get_all_api_routes # Global list to store dynamic models created during endpoint generation _dynamic_models = [] # Global mapping from (path, method) to webmethod for parameter description extraction _path_webmethod_map: dict[tuple[str, str], Any] = {} def _get_all_api_routes_with_functions(): """ Get all API routes with their actual function references. This is a modified version of get_all_api_routes that includes the function. """ from aiohttp import hdrs from starlette.routing import Route from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup apis = {} protocols = api_protocol_map() toolgroup_protocols = { SpecialToolGroup.rag_tool: RAGToolRuntime, } for api, protocol in protocols.items(): routes = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) # HACK ALERT if api == Api.tool_runtime: for tool_group in SpecialToolGroup: sub_protocol = toolgroup_protocols[tool_group] sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction) for name, method in sub_protocol_methods: if not hasattr(method, "__webmethod__"): continue protocol_methods.append((f"{tool_group.value}.{name}", method)) for name, method in protocol_methods: # Get all webmethods for this method (supports multiple decorators) webmethods = getattr(method, "__webmethods__", []) if not webmethods: continue # Create routes for each webmethod decorator for webmethod in webmethods: path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}" if webmethod.method == hdrs.METH_GET: http_method = hdrs.METH_GET elif webmethod.method == hdrs.METH_DELETE: http_method = hdrs.METH_DELETE else: http_method = hdrs.METH_POST # Store the function reference in the webmethod webmethod.func = method routes.append((Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)) apis[api] = routes return apis def create_llama_stack_app() -> FastAPI: """ Create a FastAPI app that represents the Llama Stack API. This uses the existing route discovery system to automatically find all routes. """ app = FastAPI( title="Llama Stack API", description="A comprehensive API for building and deploying AI applications", version="1.0.0", servers=[ {"url": "https://api.llamastack.com", "description": "Production server"}, {"url": "https://staging-api.llamastack.com", "description": "Staging server"}, ], ) # Get all API routes using the modified system that includes functions api_routes = _get_all_api_routes_with_functions() # Create FastAPI routes from the discovered routes for _, routes in api_routes.items(): for route, webmethod in routes: # Store mapping for later use in parameter description extraction for method in route.methods: _path_webmethod_map[(route.path, method.lower())] = webmethod # Convert the route to a FastAPI endpoint _create_fastapi_endpoint(app, route, webmethod) return app def _extract_path_parameters(path: str, webmethod=None) -> list[dict[str, Any]]: """ Extract path parameters from a URL path and return them as OpenAPI parameter definitions. Parameters are returned in the order they appear in the docstring if available, otherwise in the order they appear in the path. Args: path: URL path with parameters like /v1/batches/{batch_id}/cancel webmethod: Optional webmethod to extract parameter descriptions from docstring Returns: List of parameter definitions for OpenAPI """ import re # Find all path parameters in the format {param} or {param:type} param_pattern = r"\{([^}:]+)(?::[^}]+)?\}" path_params = set(re.findall(param_pattern, path)) # Extract parameter descriptions and order from docstring if available param_descriptions = {} docstring_param_order = [] if webmethod: func = getattr(webmethod, "func", None) if func and func.__doc__: docstring = func.__doc__ lines = docstring.split("\n") for line in lines: line = line.strip() if line.startswith(":param "): # Extract parameter name and description # Format: :param param_name: description parts = line[7:].split(":", 1) if len(parts) == 2: param_name = parts[0].strip() description = parts[1].strip() # Only track path parameters that exist in the path if param_name in path_params: if description: param_descriptions[param_name] = description if param_name not in docstring_param_order: docstring_param_order.append(param_name) # Build parameters list preserving docstring order for path parameters found in docstring, # then add any remaining path parameters in path order parameters = [] # First add parameters in docstring order for param_name in docstring_param_order: if param_name in path_params: description = param_descriptions.get(param_name, f"Path parameter: {param_name}") parameters.append( { "name": param_name, "in": "path", "required": True, "schema": {"type": "string"}, "description": description, } ) # Then add any path parameters not in docstring, in path order path_param_list = re.findall(param_pattern, path) for param_name in path_param_list: if param_name not in docstring_param_order: description = param_descriptions.get(param_name, f"Path parameter: {param_name}") parameters.append( { "name": param_name, "in": "path", "required": True, "schema": {"type": "string"}, "description": description, } ) return parameters def _create_fastapi_endpoint(app: FastAPI, route, webmethod): """ Create a FastAPI endpoint from a discovered route and webmethod. This creates endpoints with actual Pydantic models for proper schema generation. """ # Extract route information path = route.path methods = route.methods name = route.name # Convert path parameters from {param} to {param:path} format for FastAPI fastapi_path = path.replace("{", "{").replace("}", "}") # Try to find actual models for this endpoint request_model, response_model, query_parameters = _find_models_for_endpoint(webmethod) # Debug: Print info for safety endpoints if "safety" in webmethod.route or "shield" in webmethod.route: print( f"Debug: {webmethod.route} - request_model: {request_model}, response_model: {response_model}, query_parameters: {query_parameters}" ) # Extract summary and response description from webmethod docstring summary = _extract_summary_from_docstring(webmethod) response_description = _extract_response_description_from_docstring(webmethod, response_model) # Create endpoint function with proper typing if request_model and response_model: # POST/PUT request with request body async def typed_endpoint(request: request_model) -> response_model: """Typed endpoint for proper schema generation.""" return response_model() endpoint_func = typed_endpoint elif response_model and query_parameters: # Check if this is a POST/PUT endpoint with individual parameters # For POST/PUT, individual parameters should go in request body, not query params is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods) if is_post_put: # POST/PUT with individual parameters - create a request body model try: from pydantic import create_model # Create a dynamic Pydantic model for the request body field_definitions = {} for param_name, param_type, default_value in query_parameters: # Handle complex types that might cause issues with create_model safe_type = _make_type_safe_for_fastapi(param_type) if default_value is None: field_definitions[param_name] = (safe_type, ...) # Required field else: field_definitions[param_name] = (safe_type, default_value) # Optional field with default # Create the request model dynamically # Clean up the route name to create a valid schema name clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_") model_name = f"{clean_route}_Request" print(f"Debug: Creating model {model_name} with fields: {field_definitions}") request_model = create_model(model_name, **field_definitions) print(f"Debug: Successfully created model {model_name}") # Store the dynamic model in the global list for schema inclusion _dynamic_models.append(request_model) # Create endpoint with request body async def typed_endpoint(request: request_model) -> response_model: """Typed endpoint for proper schema generation.""" return response_model() # Set the function signature to ensure FastAPI recognizes the request model typed_endpoint.__annotations__ = {"request": request_model, "return": response_model} endpoint_func = typed_endpoint except Exception as e: # If dynamic model creation fails, fall back to query parameters print(f"Warning: Failed to create dynamic request model for {webmethod.route}: {e}") print(f" Query parameters: {query_parameters}") # Fall through to the query parameter handling pass if not is_post_put: # GET with query parameters - create a function with the actual query parameters def create_query_endpoint_func(): # Build the function signature dynamically import inspect # Create parameter annotations param_annotations = {} param_defaults = {} for param_name, param_type, default_value in query_parameters: # Handle problematic type annotations that cause FastAPI issues safe_type = _make_type_safe_for_fastapi(param_type) param_annotations[param_name] = safe_type if default_value is not None: param_defaults[param_name] = default_value # Create the function with the correct signature def create_endpoint_func(): # Sort parameters so that required parameters come before optional ones # Parameters with None default are required, others are optional sorted_params = sorted( query_parameters, key=lambda x: (x[2] is not None, x[0]), # False (required) comes before True (optional) ) # Create the function signature sig = inspect.Signature( [ inspect.Parameter( name=param_name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, default=default_value if default_value is not None else inspect.Parameter.empty, annotation=param_annotations[param_name], ) for param_name, param_type, default_value in sorted_params ] ) # Create a simple function without **kwargs async def query_endpoint(): """Query endpoint for proper schema generation.""" return response_model() # Set the signature and annotations query_endpoint.__signature__ = sig query_endpoint.__annotations__ = param_annotations return query_endpoint return create_endpoint_func() endpoint_func = create_query_endpoint_func() elif response_model: # Response-only endpoint (no parameters) async def response_only_endpoint() -> response_model: """Response-only endpoint for proper schema generation.""" return response_model() endpoint_func = response_only_endpoint else: # Fallback to generic endpoint async def generic_endpoint(*args, **kwargs): """Generic endpoint - this would be replaced with actual implementation.""" return {"message": f"Endpoint {name} not implemented in OpenAPI generator"} endpoint_func = generic_endpoint # Add the endpoint to the FastAPI app is_deprecated = webmethod.deprecated or False route_kwargs = { "name": name, "tags": [_get_tag_from_api(webmethod)], "deprecated": is_deprecated, "responses": { 200: { "description": response_description, "content": { "application/json": { "schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {} } }, }, 400: {"$ref": "#/components/responses/BadRequest400"}, 429: {"$ref": "#/components/responses/TooManyRequests429"}, 500: {"$ref": "#/components/responses/InternalServerError500"}, "default": {"$ref": "#/components/responses/DefaultError"}, }, } if summary: route_kwargs["summary"] = summary for method in methods: if method.upper() == "GET": app.get(fastapi_path, **route_kwargs)(endpoint_func) elif method.upper() == "POST": app.post(fastapi_path, **route_kwargs)(endpoint_func) elif method.upper() == "PUT": app.put(fastapi_path, **route_kwargs)(endpoint_func) elif method.upper() == "DELETE": app.delete(fastapi_path, **route_kwargs)(endpoint_func) elif method.upper() == "PATCH": app.patch(fastapi_path, **route_kwargs)(endpoint_func) def _extract_summary_from_docstring(webmethod) -> str | None: """ Extract summary from the actual function docstring. The summary is typically the first non-empty line of the docstring, before any :param:, :returns:, or other docstring field markers. """ func = getattr(webmethod, "func", None) if not func: return None docstring = func.__doc__ or "" if not docstring: return None lines = docstring.split("\n") for line in lines: line = line.strip() if not line: continue if line.startswith(":param:") or line.startswith(":returns:") or line.startswith(":raises:"): break return line return None def _extract_response_description_from_docstring(webmethod, response_model) -> str: """ Extract response description from the actual function docstring. Looks for :returns: in the docstring and uses that as the description. """ func = getattr(webmethod, "func", None) if not func: return "Successful Response" docstring = func.__doc__ or "" lines = docstring.split("\n") for line in lines: line = line.strip() if line.startswith(":returns:"): description = line[9:].strip() if description: return description return "Successful Response" def _get_tag_from_api(webmethod) -> str: """Extract a tag name from the webmethod for API grouping.""" # Extract API name from the route path if webmethod.level: return webmethod.level.replace("/", "").title() return "API" def _find_models_for_endpoint(webmethod) -> tuple[type | None, type | None, list[tuple[str, type, Any]]]: """ Find appropriate request and response models for an endpoint by analyzing the actual function signature. This uses the webmethod's function to determine the correct models dynamically. Returns: tuple: (request_model, response_model, query_parameters) where query_parameters is a list of (name, type, default_value) tuples """ try: # Get the actual function from the webmethod func = getattr(webmethod, "func", None) if not func: return None, None, [] # Analyze the function signature sig = inspect.signature(func) # Find request model (first parameter that's not 'self') request_model = None query_parameters = [] for param_name, param in sig.parameters.items(): if param_name == "self": continue # Skip *args and **kwargs parameters - these are not real API parameters if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): continue # Check if it's a Pydantic model (for POST/PUT requests) param_type = param.annotation if hasattr(param_type, "model_json_schema"): request_model = param_type break elif get_origin(param_type) is Annotated: # Handle Annotated types - get the base type args = get_args(param_type) if args and hasattr(args[0], "model_json_schema"): request_model = args[0] break else: # This is likely a query parameter for GET requests # Store the parameter info for later use default_value = param.default if param.default != inspect.Parameter.empty else None # Extract the base type from union types (e.g., str | None -> str) # Also make it safe for FastAPI to avoid forward reference issues base_type = _make_type_safe_for_fastapi(param_type) query_parameters.append((param_name, base_type, default_value)) # Find response model from return annotation response_model = None return_annotation = sig.return_annotation if return_annotation != inspect.Signature.empty: if hasattr(return_annotation, "model_json_schema"): response_model = return_annotation elif get_origin(return_annotation) is Annotated: # Handle Annotated return types args = get_args(return_annotation) if args: # Check if the first argument is a Pydantic model if hasattr(args[0], "model_json_schema"): response_model = args[0] # Check if the first argument is a union type elif get_origin(args[0]) is type(args[0]): # Union type union_args = get_args(args[0]) for arg in union_args: if hasattr(arg, "model_json_schema"): response_model = arg break elif get_origin(return_annotation) is type(return_annotation): # Union type # Handle union types - try to find the first Pydantic model args = get_args(return_annotation) for arg in args: if hasattr(arg, "model_json_schema"): response_model = arg break return request_model, response_model, query_parameters except Exception: # If we can't analyze the function signature, return None return None, None, [] def _make_type_safe_for_fastapi(type_hint) -> type: """ Make a type hint safe for FastAPI by converting problematic types to their base types. This handles cases like Literal["24h"] that cause forward reference errors. Also removes Union with None to avoid anyOf with type: 'null' schemas. """ # Handle Literal types that might cause issues if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Literal: args = get_args(type_hint) if args: # Get the type of the first literal value first_arg = args[0] if isinstance(first_arg, str): return str elif isinstance(first_arg, int): return int elif isinstance(first_arg, float): return float elif isinstance(first_arg, bool): return bool else: return type(first_arg) # Handle Union types (Python 3.10+ uses | syntax) origin = get_origin(type_hint) if origin is type(None) or (origin is type and type_hint is type(None)): # This is just None, return None return type_hint # Handle Union types (both old Union and new | syntax) if origin is type(type_hint) or (hasattr(type_hint, "__args__") and type_hint.__args__): # This is a union type, find the non-None type args = get_args(type_hint) non_none_types = [arg for arg in args if arg is not type(None) and arg is not None] if non_none_types: # Return the first non-None type to avoid anyOf with null return non_none_types[0] elif args: # If all args are None, return the first one return args[0] else: return type_hint # Not a union type, return as-is return type_hint def _generate_schema_for_type(type_hint) -> dict[str, Any]: """ Generate a JSON schema for a given type hint. This is a simplified version that handles basic types. """ # Handle Union types (e.g., str | None) if get_origin(type_hint) is type(None) or (get_origin(type_hint) is type and type_hint is type(None)): return {"type": "null"} # Handle list types if get_origin(type_hint) is list: args = get_args(type_hint) if args: item_type = args[0] return {"type": "array", "items": _generate_schema_for_type(item_type)} return {"type": "array"} # Handle basic types if type_hint is str: return {"type": "string"} elif type_hint is int: return {"type": "integer"} elif type_hint is float: return {"type": "number"} elif type_hint is bool: return {"type": "boolean"} elif type_hint is dict: return {"type": "object"} elif type_hint is list: return {"type": "array"} # For complex types, try to get the schema from Pydantic try: if hasattr(type_hint, "model_json_schema"): return type_hint.model_json_schema() elif hasattr(type_hint, "__name__"): return {"$ref": f"#/components/schemas/{type_hint.__name__}"} except Exception: pass # Fallback return {"type": "object"} def _add_llama_stack_extensions(openapi_schema: dict[str, Any], app: FastAPI) -> dict[str, Any]: """ Add Llama Stack specific extensions to the OpenAPI schema. This includes x-llama-stack-extra-body-params for ExtraBodyField parameters. """ # Get all API routes to find functions with ExtraBodyField parameters api_routes = get_all_api_routes() for api_name, routes in api_routes.items(): for route, webmethod in routes: # Extract path and method path = route.path methods = route.methods for method in methods: method_lower = method.lower() if method_lower in openapi_schema.get("paths", {}).get(path, {}): operation = openapi_schema["paths"][path][method_lower] # Try to find the actual function that implements this route # and extract its ExtraBodyField parameters extra_body_params = _find_extra_body_params_for_route(api_name, route, webmethod) if extra_body_params: operation["x-llama-stack-extra-body-params"] = extra_body_params return openapi_schema def _find_extra_body_params_for_route(api_name: str, route, webmethod) -> list[dict[str, Any]]: """ Find the actual function that implements a route and extract its ExtraBodyField parameters. """ try: # Try to get the actual function from the API protocol map from llama_stack.core.resolver import api_protocol_map # Look up the API implementation if api_name in api_protocol_map: _ = api_protocol_map[api_name] # Try to find the method that matches this route # This is a simplified approach - we'd need to map the route to the actual method # For now, we'll return an empty list to avoid hardcoding return [] return [] except Exception: # If we can't find the function, return empty list return [] def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Ensure all @json_schema_type decorated models are included in the OpenAPI schema. This finds all models with the _llama_stack_schema_type attribute and adds them to the schema. """ if "components" not in openapi_schema: openapi_schema["components"] = {} if "schemas" not in openapi_schema["components"]: openapi_schema["components"]["schemas"] = {} # Find all classes with the _llama_stack_schema_type attribute from llama_stack import apis # Get all modules in the apis package apis_modules = [] for module_name in dir(apis): if not module_name.startswith("_"): try: module = getattr(apis, module_name) if hasattr(module, "__file__"): apis_modules.append(module) except (ImportError, AttributeError): continue # Also check submodules for module in apis_modules: for attr_name in dir(module): if not attr_name.startswith("_"): try: attr = getattr(module, attr_name) if hasattr(attr, "__file__") and hasattr(attr, "__name__"): apis_modules.append(attr) except (ImportError, AttributeError): continue # Find all classes with the _llama_stack_schema_type attribute for module in apis_modules: for attr_name in dir(module): try: attr = getattr(module, attr_name) if ( hasattr(attr, "_llama_stack_schema_type") and hasattr(attr, "model_json_schema") and hasattr(attr, "__name__") ): schema_name = attr.__name__ if schema_name not in openapi_schema["components"]["schemas"]: try: schema = attr.model_json_schema() openapi_schema["components"]["schemas"][schema_name] = schema except Exception: # Skip if we can't generate the schema continue except (AttributeError, TypeError): continue # Also include any dynamic models that were created during endpoint generation # This is a workaround to ensure dynamic models appear in the schema global _dynamic_models if "_dynamic_models" in globals(): for model in _dynamic_models: try: schema_name = model.__name__ if schema_name not in openapi_schema["components"]["schemas"]: schema = model.model_json_schema() openapi_schema["components"]["schemas"][schema_name] = schema except Exception: # Skip if we can't generate the schema continue return openapi_schema def _fix_ref_references(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Fix $ref references to point to components/schemas instead of $defs. This prevents the YAML dumper from creating a root-level $defs section. """ def fix_refs(obj: Any) -> None: if isinstance(obj, dict): if "$ref" in obj and obj["$ref"].startswith("#/$defs/"): # Replace #/$defs/ with #/components/schemas/ obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/") for value in obj.values(): fix_refs(value) elif isinstance(obj, list): for item in obj: fix_refs(item) fix_refs(openapi_schema) return openapi_schema def _fix_anyof_with_null(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Fix anyOf schemas that contain type: 'null' by removing the null type and making the field optional through the required field instead. """ def fix_anyof(obj: Any) -> None: if isinstance(obj, dict): if "anyOf" in obj and isinstance(obj["anyOf"], list): # Check if anyOf contains type: 'null' has_null = any(item.get("type") == "null" for item in obj["anyOf"] if isinstance(item, dict)) if has_null: # Remove null types and keep only the non-null types non_null_types = [ item for item in obj["anyOf"] if not (isinstance(item, dict) and item.get("type") == "null") ] if len(non_null_types) == 1: # If only one non-null type remains, replace anyOf with that type obj.update(non_null_types[0]) if "anyOf" in obj: del obj["anyOf"] else: # Keep the anyOf but without null types obj["anyOf"] = non_null_types # Recursively process all values for value in obj.values(): fix_anyof(value) elif isinstance(obj, list): for item in obj: fix_anyof(item) fix_anyof(openapi_schema) return openapi_schema def _eliminate_defs_section(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Eliminate $defs section entirely by moving all definitions to components/schemas. This matches the structure of the old pyopenapi generator for oasdiff compatibility. """ if "components" not in openapi_schema: openapi_schema["components"] = {} if "schemas" not in openapi_schema["components"]: openapi_schema["components"]["schemas"] = {} # First pass: collect all $defs from anywhere in the schema defs_to_move = {} def collect_defs(obj: Any) -> None: if isinstance(obj, dict): if "$defs" in obj: # Collect $defs for later processing for def_name, def_schema in obj["$defs"].items(): if def_name not in defs_to_move: defs_to_move[def_name] = def_schema # Recursively process all values for value in obj.values(): collect_defs(value) elif isinstance(obj, list): for item in obj: collect_defs(item) # Collect all $defs collect_defs(openapi_schema) # Move all $defs to components/schemas for def_name, def_schema in defs_to_move.items(): if def_name not in openapi_schema["components"]["schemas"]: openapi_schema["components"]["schemas"][def_name] = def_schema # Also move any existing root-level $defs to components/schemas if "$defs" in openapi_schema: print(f"Found root-level $defs with {len(openapi_schema['$defs'])} items, moving to components/schemas") for def_name, def_schema in openapi_schema["$defs"].items(): if def_name not in openapi_schema["components"]["schemas"]: openapi_schema["components"]["schemas"][def_name] = def_schema # Remove the root-level $defs del openapi_schema["$defs"] # Second pass: remove all $defs sections from anywhere in the schema def remove_defs(obj: Any) -> None: if isinstance(obj, dict): if "$defs" in obj: del obj["$defs"] # Recursively process all values for value in obj.values(): remove_defs(value) elif isinstance(obj, list): for item in obj: remove_defs(item) # Remove all $defs sections remove_defs(openapi_schema) return openapi_schema def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Add standard error response definitions to the OpenAPI schema. Uses the actual Error model from the codebase for consistency. """ if "components" not in openapi_schema: openapi_schema["components"] = {} if "responses" not in openapi_schema["components"]: openapi_schema["components"]["responses"] = {} # Import the actual Error model try: from llama_stack.apis.datatypes import Error # Generate the Error schema using Pydantic error_schema = Error.model_json_schema() # Ensure the Error schema is in the components/schemas if "schemas" not in openapi_schema["components"]: openapi_schema["components"]["schemas"] = {} # Only add Error schema if it doesn't already exist if "Error" not in openapi_schema["components"]["schemas"]: openapi_schema["components"]["schemas"]["Error"] = error_schema except ImportError: # Fallback if we can't import the Error model error_schema = {"$ref": "#/components/schemas/Error"} # Define standard HTTP error responses error_responses = { 400: { "name": "BadRequest400", "description": "The request was invalid or malformed", "example": {"status": 400, "title": "Bad Request", "detail": "The request was invalid or malformed"}, }, 429: { "name": "TooManyRequests429", "description": "The client has sent too many requests in a given amount of time", "example": { "status": 429, "title": "Too Many Requests", "detail": "You have exceeded the rate limit. Please try again later.", }, }, 500: { "name": "InternalServerError500", "description": "The server encountered an unexpected error", "example": { "status": 500, "title": "Internal Server Error", "detail": "An unexpected error occurred. Our team has been notified.", }, }, } # Add each error response to the schema for _, error_info in error_responses.items(): response_name = error_info["name"] openapi_schema["components"]["responses"][response_name] = { "description": error_info["description"], "content": { "application/json": {"schema": {"$ref": "#/components/schemas/Error"}, "example": error_info["example"]} }, } # Add a default error response openapi_schema["components"]["responses"]["DefaultError"] = { "description": "An unexpected error occurred", "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Error"}}}, } return openapi_schema def _fix_path_parameters(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Fix path parameter resolution issues by adding explicit parameter definitions. Uses docstring descriptions if available. """ global _path_webmethod_map if "paths" not in openapi_schema: return openapi_schema for path, path_item in openapi_schema["paths"].items(): # Add parameters to each operation in this path for method in ["get", "post", "put", "delete", "patch", "head", "options"]: if method in path_item and isinstance(path_item[method], dict): operation = path_item[method] # Get webmethod for this path/method to extract parameter descriptions webmethod = _path_webmethod_map.get((path, method)) # Extract path parameters from the URL with descriptions from docstring path_params = _extract_path_parameters(path, webmethod) if not path_params: continue if "parameters" not in operation: operation["parameters"] = [] # Separate path and non-path parameters existing_params = operation["parameters"] non_path_params = [p for p in existing_params if p.get("in") != "path"] existing_path_params = {p.get("name"): p for p in existing_params if p.get("in") == "path"} # Build new parameters list: non-path params first, then path params in docstring order new_params = non_path_params.copy() # Add path parameters in docstring order for param in path_params: param_name = param["name"] if param_name in existing_path_params: # Update existing parameter description if we have a better one existing_param = existing_path_params[param_name] if param["description"] != f"Path parameter: {param_name}": existing_param["description"] = param["description"] new_params.append(existing_param) else: # Add new path parameter new_params.append(param) operation["parameters"] = new_params return openapi_schema def _extract_first_line_from_description(description: str) -> str: """ Extract all lines from a description string that don't start with docstring keywords. Stops at the first line that starts with :param:, :returns:, :raises:, etc. Preserves multiple lines and formatting. """ if not description: return description lines = description.split("\n") description_lines = [] for line in lines: stripped = line.strip() if not stripped: # Keep empty lines in the description to preserve formatting description_lines.append(line) continue if ( stripped.startswith(":param") or stripped.startswith(":returns") or stripped.startswith(":raises") or (stripped.startswith(":") and len(stripped) > 1 and stripped[1].isalpha()) ): break description_lines.append(line) # Join lines and strip trailing whitespace/newlines result = "\n".join(description_lines).rstrip() return result if result else description def _fix_component_descriptions(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Fix component descriptions to only include the first line (summary), removing :param:, :returns:, and other docstring directives. """ if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]: return openapi_schema schemas = openapi_schema["components"]["schemas"] def fix_description_in_schema(schema_def: dict[str, Any]) -> None: if isinstance(schema_def, dict): if "description" in schema_def and isinstance(schema_def["description"], str): schema_def["description"] = _extract_first_line_from_description(schema_def["description"]) for value in schema_def.values(): fix_description_in_schema(value) elif isinstance(schema_def, list): for item in schema_def: fix_description_in_schema(item) for _, schema_def in schemas.items(): fix_description_in_schema(schema_def) return openapi_schema def _fix_schema_issues(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Fix common schema issues that cause OpenAPI validation problems. This includes converting exclusiveMinimum numbers to minimum values and fixing string fields with null defaults. """ if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]: return openapi_schema schemas = openapi_schema["components"]["schemas"] # Fix exclusiveMinimum issues for _, schema_def in schemas.items(): _fix_exclusive_minimum_in_schema(schema_def) _fix_all_null_defaults(schema_def) return openapi_schema def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI schema") -> bool: """ Validate an OpenAPI schema using openapi-spec-validator. Args: schema: The OpenAPI schema dictionary to validate schema_name: Name of the schema for error reporting Returns: True if valid, False otherwise Raises: OpenAPIValidationError: If validation fails """ try: validate_spec(schema) print(f"✅ {schema_name} is valid") return True except OpenAPISpecValidatorError as e: print(f"❌ {schema_name} validation failed:") print(f" {e}") return False except Exception as e: print(f"❌ {schema_name} validation error: {e}") return False def validate_schema_file(file_path: Path) -> bool: """ Validate an OpenAPI schema file (YAML or JSON). Args: file_path: Path to the schema file Returns: True if valid, False otherwise """ try: with open(file_path) as f: if file_path.suffix.lower() in [".yaml", ".yml"]: schema = yaml.safe_load(f) elif file_path.suffix.lower() == ".json": schema = json.load(f) else: print(f"❌ Unsupported file format: {file_path.suffix}") return False return validate_openapi_schema(schema, str(file_path)) except Exception as e: print(f"❌ Failed to read {file_path}: {e}") return False def _fix_exclusive_minimum_in_schema(obj: Any) -> None: """ Recursively fix exclusiveMinimum issues in a schema object. Converts exclusiveMinimum numbers to minimum values. """ if isinstance(obj, dict): # Check if this is a schema with exclusiveMinimum if "exclusiveMinimum" in obj and isinstance(obj["exclusiveMinimum"], int | float): # Convert exclusiveMinimum number to minimum obj["minimum"] = obj["exclusiveMinimum"] del obj["exclusiveMinimum"] # Recursively process all values for value in obj.values(): _fix_exclusive_minimum_in_schema(value) elif isinstance(obj, list): # Recursively process all items for item in obj: _fix_exclusive_minimum_in_schema(item) def _fix_string_fields_with_null_defaults(obj: Any) -> None: """ Recursively fix string fields that have default: null. This violates OpenAPI spec - string fields should either have a string default or be optional. """ if isinstance(obj, dict): # Check if this is a field definition with type: string and default: null if obj.get("type") == "string" and "default" in obj and obj["default"] is None: # Remove the default: null to make the field optional del obj["default"] # Add nullable: true to indicate the field can be null obj["nullable"] = True # Recursively process all values for value in obj.values(): _fix_string_fields_with_null_defaults(value) elif isinstance(obj, list): # Recursively process all items for item in obj: _fix_string_fields_with_null_defaults(item) def _fix_anyof_with_null_defaults(obj: Any) -> None: """ Recursively fix anyOf schemas that have default: null. This violates OpenAPI spec - anyOf fields should not have null defaults. """ if isinstance(obj, dict): # Check if this is a field definition with anyOf and default: null if "anyOf" in obj and "default" in obj and obj["default"] is None: # Remove the default: null to make the field optional del obj["default"] # Add nullable: true to indicate the field can be null obj["nullable"] = True # Recursively process all values for value in obj.values(): _fix_anyof_with_null_defaults(value) elif isinstance(obj, list): # Recursively process all items for item in obj: _fix_anyof_with_null_defaults(item) def _fix_all_null_defaults(obj: Any) -> None: """ Recursively fix all field types that have default: null. This violates OpenAPI spec - fields should not have null defaults. """ if isinstance(obj, dict): # Check if this is a field definition with default: null if "default" in obj and obj["default"] is None: # Remove the default: null to make the field optional del obj["default"] # Add nullable: true to indicate the field can be null obj["nullable"] = True # Recursively process all values for value in obj.values(): _fix_all_null_defaults(value) elif isinstance(obj, list): # Recursively process all items for item in obj: _fix_all_null_defaults(item) def _sort_paths_alphabetically(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Sort the paths in the OpenAPI schema by version prefix first, then alphabetically. Also sort HTTP methods alphabetically within each path. Version order: v1beta, v1alpha, v1 """ if "paths" not in openapi_schema: return openapi_schema def path_sort_key(path: str) -> tuple: """ Create a sort key that groups paths by version prefix first. Returns (version_priority, path) where version_priority: - 0 for v1beta - 1 for v1alpha - 2 for v1 - 3 for others """ if path.startswith("/v1beta/"): version_priority = 0 elif path.startswith("/v1alpha/"): version_priority = 1 elif path.startswith("/v1/"): version_priority = 2 else: version_priority = 3 return (version_priority, path) def sort_path_item(path_item: dict[str, Any]) -> dict[str, Any]: """Sort HTTP methods alphabetically within a path item.""" if not isinstance(path_item, dict): return path_item # Define the order of HTTP methods method_order = ["delete", "get", "head", "options", "patch", "post", "put", "trace"] # Create a new ordered dict with methods in alphabetical order sorted_path_item = {} # First add methods in the defined order for method in method_order: if method in path_item: sorted_path_item[method] = path_item[method] # Then add any other keys that aren't HTTP methods for key, value in path_item.items(): if key not in method_order: sorted_path_item[key] = value return sorted_path_item # Sort paths by version prefix first, then alphabetically # Also sort HTTP methods within each path sorted_paths = {} for path, path_item in sorted(openapi_schema["paths"].items(), key=lambda x: path_sort_key(x[0])): sorted_paths[path] = sort_path_item(path_item) openapi_schema["paths"] = sorted_paths return openapi_schema def _filter_schema_by_version( openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True ) -> dict[str, Any]: """ Filter OpenAPI schema by API version. Args: openapi_schema: The full OpenAPI schema stable_only: If True, return only /v1/ paths (stable). If False, return only /v1alpha/ and /v1beta/ paths (experimental). exclude_deprecated: If True, exclude deprecated endpoints from the result. Returns: Filtered OpenAPI schema """ filtered_schema = openapi_schema.copy() if "paths" not in filtered_schema: return filtered_schema # Filter paths based on version prefix and deprecated status filtered_paths = {} for path, path_item in filtered_schema["paths"].items(): # Check if path has any deprecated operations is_deprecated = _is_path_deprecated(path_item) # Skip deprecated endpoints if exclude_deprecated is True if exclude_deprecated and is_deprecated: continue if stable_only: # Only include /v1/ paths, exclude /v1alpha/ and /v1beta/ if path.startswith("/v1/") and not path.startswith("/v1alpha/") and not path.startswith("/v1beta/"): filtered_paths[path] = path_item else: # Only include /v1alpha/ and /v1beta/ paths, exclude /v1/ if path.startswith("/v1alpha/") or path.startswith("/v1beta/"): filtered_paths[path] = path_item filtered_schema["paths"] = filtered_paths # Filter schemas/components to only include ones referenced by filtered paths if "components" in filtered_schema and "schemas" in filtered_schema["components"]: # Find all schemas that are actually referenced by the filtered paths # Use the original schema to find all references, not the filtered one referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema) # Only keep schemas that are referenced by the filtered paths filtered_schemas = {} for schema_name, schema_def in filtered_schema["components"]["schemas"].items(): if schema_name in referenced_schemas: filtered_schemas[schema_name] = schema_def filtered_schema["components"]["schemas"] = filtered_schemas # Preserve $defs section if it exists if "components" in openapi_schema and "$defs" in openapi_schema["components"]: if "components" not in filtered_schema: filtered_schema["components"] = {} filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"] print(f"Preserved $defs section with {len(openapi_schema['components']['$defs'])} items") else: print("No $defs section to preserve") return filtered_schema def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]: """ Find all schemas that are referenced by the filtered paths. This recursively traverses the path definitions to find all $ref references. """ referenced_schemas = set() # Traverse all filtered paths for _, path_item in filtered_paths.items(): if not isinstance(path_item, dict): continue # Check each HTTP method in the path for method in ["get", "post", "put", "delete", "patch", "head", "options"]: if method in path_item: operation = path_item[method] if isinstance(operation, dict): # Find all schema references in this operation referenced_schemas.update(_find_schema_refs_in_object(operation)) # Also check the responses section for schema references if "components" in openapi_schema and "responses" in openapi_schema["components"]: referenced_schemas.update(_find_schema_refs_in_object(openapi_schema["components"]["responses"])) # Also include schemas that are referenced by other schemas (transitive references) # This ensures we include all dependencies all_schemas = openapi_schema.get("components", {}).get("schemas", {}) additional_schemas = set() for schema_name in referenced_schemas: if schema_name in all_schemas: additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name])) # Keep adding transitive references until no new ones are found while additional_schemas: new_schemas = additional_schemas - referenced_schemas if not new_schemas: break referenced_schemas.update(new_schemas) additional_schemas = set() for schema_name in new_schemas: if schema_name in all_schemas: additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name])) return referenced_schemas def _find_schema_refs_in_object(obj: Any) -> set[str]: """ Recursively find all schema references ($ref) in an object. """ refs = set() if isinstance(obj, dict): for key, value in obj.items(): if key == "$ref" and isinstance(value, str) and value.startswith("#/components/schemas/"): schema_name = value.split("/")[-1] refs.add(schema_name) else: refs.update(_find_schema_refs_in_object(value)) elif isinstance(obj, list): for item in obj: refs.update(_find_schema_refs_in_object(item)) return refs def _is_path_deprecated(path_item: dict[str, Any]) -> bool: """ Check if a path item has any deprecated operations. """ if not isinstance(path_item, dict): return False # Check each HTTP method in the path item for method in ["get", "post", "put", "delete", "patch", "head", "options"]: if method in path_item: operation = path_item[method] if isinstance(operation, dict) and operation.get("deprecated", False): return True return False def _filter_deprecated_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Filter OpenAPI schema to include only deprecated endpoints. Includes all deprecated endpoints regardless of version (v1, v1alpha, v1beta). """ filtered_schema = openapi_schema.copy() if "paths" not in filtered_schema: return filtered_schema # Filter paths to only include deprecated ones filtered_paths = {} for path, path_item in filtered_schema["paths"].items(): if _is_path_deprecated(path_item): filtered_paths[path] = path_item filtered_schema["paths"] = filtered_paths return filtered_schema def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Filter OpenAPI schema to include both stable (v1) and experimental (v1alpha, v1beta) APIs. Excludes deprecated endpoints. This is used for the combined "stainless" spec. """ filtered_schema = openapi_schema.copy() if "paths" not in filtered_schema: return filtered_schema # Filter paths to include stable (v1) and experimental (v1alpha, v1beta), excluding deprecated filtered_paths = {} for path, path_item in filtered_schema["paths"].items(): # Check if path has any deprecated operations is_deprecated = _is_path_deprecated(path_item) # Skip deprecated endpoints if is_deprecated: continue # Include /v1/ paths (stable) if path.startswith("/v1/") and not path.startswith("/v1alpha/") and not path.startswith("/v1beta/"): filtered_paths[path] = path_item # Include /v1alpha/ and /v1beta/ paths (experimental) elif path.startswith("/v1alpha/") or path.startswith("/v1beta/"): filtered_paths[path] = path_item filtered_schema["paths"] = filtered_paths # Filter schemas/components to only include ones referenced by filtered paths if "components" in filtered_schema and "schemas" in filtered_schema["components"]: referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema) filtered_schemas = {} for schema_name, schema_def in filtered_schema["components"]["schemas"].items(): if schema_name in referenced_schemas: filtered_schemas[schema_name] = schema_def filtered_schema["components"]["schemas"] = filtered_schemas return filtered_schema def generate_openapi_spec(output_dir: str, format: str = "yaml", include_examples: bool = True) -> dict[str, Any]: """ Generate OpenAPI specification using FastAPI's built-in method. Args: output_dir: Directory to save the generated files format: Output format ("yaml", "json", or "both") include_examples: Whether to include examples in the spec Returns: The generated OpenAPI specification as a dictionary """ # Create the FastAPI app app = create_llama_stack_app() # Generate the OpenAPI schema openapi_schema = get_openapi( title=app.title, version=app.version, description=app.description, routes=app.routes, servers=app.servers, ) # Debug: Check if there's a root-level $defs in the original schema if "$defs" in openapi_schema: print(f"Original schema has root-level $defs with {len(openapi_schema['$defs'])} items") else: print("Original schema has no root-level $defs") # Add Llama Stack specific extensions openapi_schema = _add_llama_stack_extensions(openapi_schema, app) # Add standard error responses openapi_schema = _add_error_responses(openapi_schema) # Ensure all @json_schema_type decorated models are included openapi_schema = _ensure_json_schema_types_included(openapi_schema) # Fix $ref references to point to components/schemas instead of $defs openapi_schema = _fix_ref_references(openapi_schema) # Debug: Check if there are any $ref references to $defs in the schema defs_refs = [] def find_defs_refs(obj: Any, path: str = "") -> None: if isinstance(obj, dict): if "$ref" in obj and obj["$ref"].startswith("#/$defs/"): defs_refs.append(f"{path}: {obj['$ref']}") for key, value in obj.items(): find_defs_refs(value, f"{path}.{key}" if path else key) elif isinstance(obj, list): for i, item in enumerate(obj): find_defs_refs(item, f"{path}[{i}]") find_defs_refs(openapi_schema) if defs_refs: print(f"Found {len(defs_refs)} $ref references to $defs in schema") for ref in defs_refs[:5]: # Show first 5 print(f" {ref}") else: print("No $ref references to $defs found in schema") # Note: Let Pydantic/FastAPI generate the correct, standards-compliant schema # Fields with default values should be optional according to OpenAPI standards # Fix anyOf schemas with type: 'null' to avoid oasdiff errors openapi_schema = _fix_anyof_with_null(openapi_schema) # Fix path parameter resolution issues openapi_schema = _fix_path_parameters(openapi_schema) # Eliminate $defs section entirely for oasdiff compatibility openapi_schema = _eliminate_defs_section(openapi_schema) # Fix component descriptions to only include first line (summary) openapi_schema = _fix_component_descriptions(openapi_schema) # Debug: Check if there's a root-level $defs after flattening if "$defs" in openapi_schema: print(f"After flattening: root-level $defs with {len(openapi_schema['$defs'])} items") else: print("After flattening: no root-level $defs") # Ensure all referenced schemas are included # DISABLED: This was using hardcoded schema generation. FastAPI should handle this automatically. # openapi_schema = _ensure_referenced_schemas(openapi_schema) # Control schema registration based on @json_schema_type decorator # Temporarily disabled to fix missing schema issues # openapi_schema = _control_schema_registration(openapi_schema) # Fix malformed schemas after all other processing # DISABLED: This was a hardcoded workaround. Using Pydantic's TypeAdapter instead. # _fix_malformed_schemas(openapi_schema) # Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs # Each spec needs its own deep copy of the full schema to avoid cross-contamination import copy stable_schema = _filter_schema_by_version(copy.deepcopy(openapi_schema), stable_only=True, exclude_deprecated=True) experimental_schema = _filter_schema_by_version( copy.deepcopy(openapi_schema), stable_only=False, exclude_deprecated=True ) deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema)) combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema)) # Update title and description for combined schema if "info" in combined_schema: combined_schema["info"]["title"] = "Llama Stack API - Stable & Experimental APIs" combined_schema["info"]["description"] = ( combined_schema["info"].get("description", "") + "\n\n**🔗 COMBINED**: This specification includes both stable production-ready APIs and experimental pre-release APIs. " "Use stable APIs for production deployments and experimental APIs for testing new features." ) # Sort paths alphabetically for stable (v1 only) stable_schema = _sort_paths_alphabetically(stable_schema) # Sort paths by version prefix for experimental (v1beta, v1alpha) experimental_schema = _sort_paths_alphabetically(experimental_schema) # Sort paths by version prefix for deprecated deprecated_schema = _sort_paths_alphabetically(deprecated_schema) # Sort paths by version prefix for combined (stainless) combined_schema = _sort_paths_alphabetically(combined_schema) # Fix schema issues (like exclusiveMinimum -> minimum) for each spec stable_schema = _fix_schema_issues(stable_schema) experimental_schema = _fix_schema_issues(experimental_schema) deprecated_schema = _fix_schema_issues(deprecated_schema) combined_schema = _fix_schema_issues(combined_schema) # Validate the schemas print("\n🔍 Validating generated schemas...") stable_valid = validate_openapi_schema(stable_schema, "Stable schema") experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema") deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema") combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema") if not all([stable_valid, experimental_valid, deprecated_valid, combined_valid]): print("⚠️ Some schemas failed validation, but continuing with generation...") # Add any custom modifications here if needed if include_examples: # Add examples to the schema if needed pass # Ensure output directory exists output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # Save the stable specification if format in ["yaml", "both"]: yaml_path = output_path / "llama-stack-spec.yaml" # Use ruamel.yaml for better YAML formatting try: from ruamel.yaml import YAML yaml_writer = YAML() yaml_writer.default_flow_style = False yaml_writer.sort_keys = False yaml_writer.width = 4096 # Prevent line wrapping yaml_writer.allow_unicode = True with open(yaml_path, "w") as f: yaml_writer.dump(stable_schema, f) # Post-process the YAML file to remove $defs section and fix references # Re-read and re-write with ruamel.yaml with open(yaml_path) as f: yaml_content = f.read() if " $defs:" in yaml_content or "#/$defs/" in yaml_content: print("Post-processing YAML to remove $defs section") # Use string replacement to fix references directly if "#/$defs/" in yaml_content: refs_fixed = yaml_content.count("#/$defs/") yaml_content = yaml_content.replace("#/$defs/", "#/components/schemas/") print(f"Fixed {refs_fixed} $ref references using string replacement") # Parse using PyYAML safe_load first to avoid issues with custom types # This handles block scalars better during post-processing import yaml as pyyaml with open(yaml_path) as f: yaml_data = pyyaml.safe_load(f) # Move $defs to components/schemas if it exists if "$defs" in yaml_data: print(f"Found $defs section with {len(yaml_data['$defs'])} items") if "components" not in yaml_data: yaml_data["components"] = {} if "schemas" not in yaml_data["components"]: yaml_data["components"]["schemas"] = {} # Move all $defs to components/schemas for def_name, def_schema in yaml_data["$defs"].items(): yaml_data["components"]["schemas"][def_name] = def_schema # Remove the $defs section del yaml_data["$defs"] print("Moved $defs to components/schemas") # Write the modified YAML back with ruamel.yaml with open(yaml_path, "w") as f: yaml_writer.dump(yaml_data, f) print("Updated YAML file") except ImportError: # Fallback to standard yaml if ruamel.yaml is not available with open(yaml_path, "w") as f: yaml.dump(stable_schema, f, default_flow_style=False, sort_keys=False) print(f"✅ Generated YAML (stable): {yaml_path}") experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml" with open(experimental_yaml_path, "w") as f: yaml.dump(experimental_schema, f, default_flow_style=False, sort_keys=False) print(f"✅ Generated YAML (experimental): {experimental_yaml_path}") deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml" with open(deprecated_yaml_path, "w") as f: yaml.dump(deprecated_schema, f, default_flow_style=False, sort_keys=False) print(f"✅ Generated YAML (deprecated): {deprecated_yaml_path}") # Generate combined (stainless) spec stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml" try: from ruamel.yaml import YAML yaml_writer = YAML() yaml_writer.default_flow_style = False yaml_writer.sort_keys = False yaml_writer.width = 4096 # Prevent line wrapping yaml_writer.allow_unicode = True with open(stainless_yaml_path, "w") as f: yaml_writer.dump(combined_schema, f) except ImportError: # Fallback to standard yaml if ruamel.yaml is not available with open(stainless_yaml_path, "w") as f: yaml.dump(combined_schema, f, default_flow_style=False, sort_keys=False) print(f"✅ Generated YAML (stainless/combined): {stainless_yaml_path}") if format in ["json", "both"]: json_path = output_path / "llama-stack-spec.json" with open(json_path, "w") as f: json.dump(stable_schema, f, indent=2) print(f"✅ Generated JSON (stable): {json_path}") experimental_json_path = output_path / "experimental-llama-stack-spec.json" with open(experimental_json_path, "w") as f: json.dump(experimental_schema, f, indent=2) print(f"✅ Generated JSON (experimental): {experimental_json_path}") deprecated_json_path = output_path / "deprecated-llama-stack-spec.json" with open(deprecated_json_path, "w") as f: json.dump(deprecated_schema, f, indent=2) print(f"✅ Generated JSON (deprecated): {deprecated_json_path}") stainless_json_path = output_path / "stainless-llama-stack-spec.json" with open(stainless_json_path, "w") as f: json.dump(combined_schema, f, indent=2) print(f"✅ Generated JSON (stainless/combined): {stainless_json_path}") # Generate HTML documentation html_path = output_path / "llama-stack-spec.html" generate_html_docs(stable_schema, html_path) print(f"✅ Generated HTML: {html_path}") experimental_html_path = output_path / "experimental-llama-stack-spec.html" generate_html_docs(experimental_schema, experimental_html_path, spec_file="experimental-llama-stack-spec.yaml") print(f"✅ Generated HTML (experimental): {experimental_html_path}") deprecated_html_path = output_path / "deprecated-llama-stack-spec.html" generate_html_docs(deprecated_schema, deprecated_html_path, spec_file="deprecated-llama-stack-spec.yaml") print(f"✅ Generated HTML (deprecated): {deprecated_html_path}") stainless_html_path = output_path / "stainless-llama-stack-spec.html" generate_html_docs(combined_schema, stainless_html_path, spec_file="stainless-llama-stack-spec.yaml") print(f"✅ Generated HTML (stainless/combined): {stainless_html_path}") return stable_schema def generate_html_docs( openapi_schema: dict[str, Any], output_path: Path, spec_file: str = "llama-stack-spec.yaml" ) -> None: """Generate HTML documentation using ReDoc.""" html_template = f"""