llama-stack-mirror/scripts/fastapi_generator.py
Sébastien Han 357be98279
wip2
Signed-off-by: Sébastien Han <seb@redhat.com>
2025-11-04 10:23:07 +01:00

1896 lines
74 KiB
Python
Executable file

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
FastAPI-based OpenAPI generator for Llama Stack.
"""
import inspect
import json
from pathlib import Path
from typing import Annotated, Any, Literal, get_args, get_origin
import yaml
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
from llama_stack.apis.datatypes import Api
from llama_stack.core.resolver import api_protocol_map
# Import the existing route discovery system
from llama_stack.core.server.routes import get_all_api_routes
# Pydantic models created on the fly while building endpoints (see
# _create_fastapi_endpoint); collected here so they can later be injected
# into components/schemas by _ensure_json_schema_types_included.
_dynamic_models = []
# Maps (route path, lowercased HTTP method) -> webmethod, populated while the
# app is assembled; used afterwards to recover parameter descriptions from
# the implementing functions' docstrings (see _fix_path_parameters).
_path_webmethod_map: dict[tuple[str, str], Any] = {}
def _get_all_api_routes_with_functions():
    """
    Get all API routes with their actual function references.

    This is a modified version of get_all_api_routes that includes the function.
    """
    from aiohttp import hdrs
    from starlette.routing import Route

    from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup

    apis = {}
    protocols = api_protocol_map()
    toolgroup_protocols = {
        SpecialToolGroup.rag_tool: RAGToolRuntime,
    }

    for api, protocol in protocols.items():
        discovered = []
        members = inspect.getmembers(protocol, predicate=inspect.isfunction)

        # HACK ALERT: tool_runtime exposes extra sub-protocols (e.g. the RAG
        # tool group) whose webmethods are merged in under a prefixed name.
        if api == Api.tool_runtime:
            for tool_group in SpecialToolGroup:
                sub_protocol = toolgroup_protocols[tool_group]
                for sub_name, sub_method in inspect.getmembers(sub_protocol, predicate=inspect.isfunction):
                    if hasattr(sub_method, "__webmethod__"):
                        members.append((f"{tool_group.value}.{sub_name}", sub_method))

        for method_name, method in members:
            # A method may carry several @webmethod decorators; each decorator
            # contributes its own route. Methods without any are skipped.
            for webmethod in getattr(method, "__webmethods__", []):
                path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
                # Only GET and DELETE are passed through; everything else is
                # registered as POST.
                if webmethod.method == hdrs.METH_GET:
                    http_method = hdrs.METH_GET
                elif webmethod.method == hdrs.METH_DELETE:
                    http_method = hdrs.METH_DELETE
                else:
                    http_method = hdrs.METH_POST
                # Remember the implementing function so the generator can later
                # analyze its signature and docstring.
                webmethod.func = method
                discovered.append((Route(path=path, methods=[http_method], name=method_name, endpoint=None), webmethod))
        apis[api] = discovered

    return apis
def create_llama_stack_app() -> FastAPI:
    """
    Create a FastAPI app that represents the Llama Stack API.

    This uses the existing route discovery system to automatically find all routes.
    """
    app = FastAPI(
        title="Llama Stack API",
        description="A comprehensive API for building and deploying AI applications",
        version="1.0.0",
        servers=[
            {"url": "https://api.llamastack.com", "description": "Production server"},
            {"url": "https://staging-api.llamastack.com", "description": "Staging server"},
        ],
    )

    # Register every discovered route (with its implementing function).
    for routes in _get_all_api_routes_with_functions().values():
        for route, webmethod in routes:
            # Remember which webmethod backs each (path, method) pair so that
            # parameter descriptions can be recovered from docstrings later.
            for http_method in route.methods:
                _path_webmethod_map[(route.path, http_method.lower())] = webmethod
            _create_fastapi_endpoint(app, route, webmethod)
    return app
def _extract_path_parameters(path: str, webmethod=None) -> list[dict[str, Any]]:
"""
Extract path parameters from a URL path and return them as OpenAPI parameter definitions.
Parameters are returned in the order they appear in the docstring if available,
otherwise in the order they appear in the path.
Args:
path: URL path with parameters like /v1/batches/{batch_id}/cancel
webmethod: Optional webmethod to extract parameter descriptions from docstring
Returns:
List of parameter definitions for OpenAPI
"""
import re
# Find all path parameters in the format {param} or {param:type}
param_pattern = r"\{([^}:]+)(?::[^}]+)?\}"
path_params = set(re.findall(param_pattern, path))
# Extract parameter descriptions and order from docstring if available
param_descriptions = {}
docstring_param_order = []
if webmethod:
func = getattr(webmethod, "func", None)
if func and func.__doc__:
docstring = func.__doc__
lines = docstring.split("\n")
for line in lines:
line = line.strip()
if line.startswith(":param "):
# Extract parameter name and description
# Format: :param param_name: description
parts = line[7:].split(":", 1)
if len(parts) == 2:
param_name = parts[0].strip()
description = parts[1].strip()
# Only track path parameters that exist in the path
if param_name in path_params:
if description:
param_descriptions[param_name] = description
if param_name not in docstring_param_order:
docstring_param_order.append(param_name)
# Build parameters list preserving docstring order for path parameters found in docstring,
# then add any remaining path parameters in path order
parameters = []
# First add parameters in docstring order
for param_name in docstring_param_order:
if param_name in path_params:
description = param_descriptions.get(param_name, f"Path parameter: {param_name}")
parameters.append(
{
"name": param_name,
"in": "path",
"required": True,
"schema": {"type": "string"},
"description": description,
}
)
# Then add any path parameters not in docstring, in path order
path_param_list = re.findall(param_pattern, path)
for param_name in path_param_list:
if param_name not in docstring_param_order:
description = param_descriptions.get(param_name, f"Path parameter: {param_name}")
parameters.append(
{
"name": param_name,
"in": "path",
"required": True,
"schema": {"type": "string"},
"description": description,
}
)
return parameters
def _create_fastapi_endpoint(app: FastAPI, route, webmethod):
    """
    Create a FastAPI endpoint from a discovered route and webmethod.

    This creates endpoints with actual Pydantic models for proper schema generation.
    The endpoint bodies are stubs that only instantiate the response model — the
    app exists solely so FastAPI can emit an OpenAPI schema, not to serve traffic.
    """
    # Extract route information
    path = route.path
    methods = route.methods
    name = route.name
    # Convert path parameters from {param} to {param:path} format for FastAPI
    # NOTE(review): both .replace() calls are no-ops ("{" -> "{", "}" -> "}"),
    # so fastapi_path is always identical to path — confirm whether a real
    # conversion was intended here.
    fastapi_path = path.replace("{", "{").replace("}", "}")
    # Try to find actual models for this endpoint
    request_model, response_model, query_parameters = _find_models_for_endpoint(webmethod)
    # Debug: Print info for safety endpoints
    if "safety" in webmethod.route or "shield" in webmethod.route:
        print(
            f"Debug: {webmethod.route} - request_model: {request_model}, response_model: {response_model}, query_parameters: {query_parameters}"
        )
    # Extract summary and response description from webmethod docstring
    summary = _extract_summary_from_docstring(webmethod)
    response_description = _extract_response_description_from_docstring(webmethod, response_model)
    # Create endpoint function with proper typing
    if request_model and response_model:
        # POST/PUT request with a single Pydantic request-body model
        async def typed_endpoint(request: request_model) -> response_model:
            """Typed endpoint for proper schema generation."""
            return response_model()

        endpoint_func = typed_endpoint
    elif response_model and query_parameters:
        # Check if this is a POST/PUT endpoint with individual parameters
        # For POST/PUT, individual parameters should go in request body, not query params
        is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
        if is_post_put:
            # POST/PUT with individual parameters - create a request body model
            try:
                from pydantic import create_model

                # Create a dynamic Pydantic model for the request body
                field_definitions = {}
                for param_name, param_type, default_value in query_parameters:
                    # Handle complex types that might cause issues with create_model
                    safe_type = _make_type_safe_for_fastapi(param_type)
                    if default_value is None:
                        field_definitions[param_name] = (safe_type, ...)  # Required field
                    else:
                        field_definitions[param_name] = (safe_type, default_value)  # Optional field with default
                # Create the request model dynamically
                # Clean up the route name to create a valid schema name
                clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
                model_name = f"{clean_route}_Request"
                print(f"Debug: Creating model {model_name} with fields: {field_definitions}")
                request_model = create_model(model_name, **field_definitions)
                print(f"Debug: Successfully created model {model_name}")
                # Store the dynamic model in the global list for schema inclusion
                _dynamic_models.append(request_model)

                # Create endpoint with request body
                async def typed_endpoint(request: request_model) -> response_model:
                    """Typed endpoint for proper schema generation."""
                    return response_model()

                # Set the function signature to ensure FastAPI recognizes the request model
                typed_endpoint.__annotations__ = {"request": request_model, "return": response_model}
                endpoint_func = typed_endpoint
            except Exception as e:
                # If dynamic model creation fails, fall back to query parameters
                # NOTE(review): the fall-back never actually happens — is_post_put
                # is still True, so the `if not is_post_put` branch below is
                # skipped and endpoint_func is left unbound, which raises
                # NameError at route registration time. Confirm and fix.
                print(f"Warning: Failed to create dynamic request model for {webmethod.route}: {e}")
                print(f" Query parameters: {query_parameters}")
                # Fall through to the query parameter handling
                pass
        if not is_post_put:
            # GET with query parameters - create a function with the actual query parameters
            def create_query_endpoint_func():
                # Build the function signature dynamically
                import inspect

                # Create parameter annotations
                param_annotations = {}
                param_defaults = {}
                for param_name, param_type, default_value in query_parameters:
                    # Handle problematic type annotations that cause FastAPI issues
                    safe_type = _make_type_safe_for_fastapi(param_type)
                    param_annotations[param_name] = safe_type
                    if default_value is not None:
                        param_defaults[param_name] = default_value
                # Create the function with the correct signature
                # NOTE(review): param_defaults is collected but never used;
                # defaults are re-derived from query_parameters below.
                def create_endpoint_func():
                    # Sort parameters so that required parameters come before optional ones
                    # Parameters with None default are required, others are optional
                    sorted_params = sorted(
                        query_parameters,
                        key=lambda x: (x[2] is not None, x[0]),  # False (required) comes before True (optional)
                    )
                    # Create the function signature
                    sig = inspect.Signature(
                        [
                            inspect.Parameter(
                                name=param_name,
                                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                default=default_value if default_value is not None else inspect.Parameter.empty,
                                annotation=param_annotations[param_name],
                            )
                            for param_name, param_type, default_value in sorted_params
                        ]
                    )

                    # Create a simple function without **kwargs
                    async def query_endpoint():
                        """Query endpoint for proper schema generation."""
                        return response_model()

                    # Set the signature and annotations so FastAPI introspects
                    # the synthetic parameters as query parameters.
                    query_endpoint.__signature__ = sig
                    query_endpoint.__annotations__ = param_annotations
                    return query_endpoint

                return create_endpoint_func()

            endpoint_func = create_query_endpoint_func()
    elif response_model:
        # Response-only endpoint (no parameters)
        async def response_only_endpoint() -> response_model:
            """Response-only endpoint for proper schema generation."""
            return response_model()

        endpoint_func = response_only_endpoint
    else:
        # Fallback to generic endpoint
        async def generic_endpoint(*args, **kwargs):
            """Generic endpoint - this would be replaced with actual implementation."""
            return {"message": f"Endpoint {name} not implemented in OpenAPI generator"}

        endpoint_func = generic_endpoint
    # Add the endpoint to the FastAPI app
    is_deprecated = webmethod.deprecated or False
    route_kwargs = {
        "name": name,
        "tags": [_get_tag_from_api(webmethod)],
        "deprecated": is_deprecated,
        "responses": {
            200: {
                "description": response_description,
                "content": {
                    "application/json": {
                        "schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {}
                    }
                },
            },
            400: {"$ref": "#/components/responses/BadRequest400"},
            429: {"$ref": "#/components/responses/TooManyRequests429"},
            500: {"$ref": "#/components/responses/InternalServerError500"},
            "default": {"$ref": "#/components/responses/DefaultError"},
        },
    }
    if summary:
        route_kwargs["summary"] = summary
    for method in methods:
        if method.upper() == "GET":
            app.get(fastapi_path, **route_kwargs)(endpoint_func)
        elif method.upper() == "POST":
            app.post(fastapi_path, **route_kwargs)(endpoint_func)
        elif method.upper() == "PUT":
            app.put(fastapi_path, **route_kwargs)(endpoint_func)
        elif method.upper() == "DELETE":
            app.delete(fastapi_path, **route_kwargs)(endpoint_func)
        elif method.upper() == "PATCH":
            app.patch(fastapi_path, **route_kwargs)(endpoint_func)
def _extract_summary_from_docstring(webmethod) -> str | None:
"""
Extract summary from the actual function docstring.
The summary is typically the first non-empty line of the docstring,
before any :param:, :returns:, or other docstring field markers.
"""
func = getattr(webmethod, "func", None)
if not func:
return None
docstring = func.__doc__ or ""
if not docstring:
return None
lines = docstring.split("\n")
for line in lines:
line = line.strip()
if not line:
continue
if line.startswith(":param:") or line.startswith(":returns:") or line.startswith(":raises:"):
break
return line
return None
def _extract_response_description_from_docstring(webmethod, response_model) -> str:
"""
Extract response description from the actual function docstring.
Looks for :returns: in the docstring and uses that as the description.
"""
func = getattr(webmethod, "func", None)
if not func:
return "Successful Response"
docstring = func.__doc__ or ""
lines = docstring.split("\n")
for line in lines:
line = line.strip()
if line.startswith(":returns:"):
description = line[9:].strip()
if description:
return description
return "Successful Response"
def _get_tag_from_api(webmethod) -> str:
"""Extract a tag name from the webmethod for API grouping."""
# Extract API name from the route path
if webmethod.level:
return webmethod.level.replace("/", "").title()
return "API"
def _find_models_for_endpoint(webmethod) -> tuple[type | None, type | None, list[tuple[str, type, Any]]]:
    """
    Find appropriate request and response models for an endpoint by analyzing the actual function signature.
    This uses the webmethod's function to determine the correct models dynamically.

    Returns:
        tuple: (request_model, response_model, query_parameters)
        where query_parameters is a list of (name, type, default_value) tuples
    """
    try:
        # Get the actual function from the webmethod
        func = getattr(webmethod, "func", None)
        if not func:
            return None, None, []
        # Analyze the function signature
        sig = inspect.signature(func)
        # Find request model (first parameter that's not 'self')
        request_model = None
        query_parameters = []
        for param_name, param in sig.parameters.items():
            if param_name == "self":
                continue
            # Skip *args and **kwargs parameters - these are not real API parameters
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            # Check if it's a Pydantic model (for POST/PUT requests)
            # (duck-typed: anything exposing model_json_schema counts)
            param_type = param.annotation
            if hasattr(param_type, "model_json_schema"):
                request_model = param_type
                break
            elif get_origin(param_type) is Annotated:
                # Handle Annotated types - get the base type
                # NOTE(review): an Annotated parameter whose base type is NOT a
                # Pydantic model is silently dropped here (neither request model
                # nor query parameter) — confirm that is intended.
                args = get_args(param_type)
                if args and hasattr(args[0], "model_json_schema"):
                    request_model = args[0]
                    break
            else:
                # This is likely a query parameter for GET requests
                # Store the parameter info for later use
                # NOTE: "no default" and an explicit None default are conflated —
                # both are recorded as None in the tuple.
                default_value = param.default if param.default != inspect.Parameter.empty else None
                # Extract the base type from union types (e.g. str | None -> str)
                # Also make it safe for FastAPI to avoid forward reference issues
                base_type = _make_type_safe_for_fastapi(param_type)
                query_parameters.append((param_name, base_type, default_value))
        # Find response model from return annotation
        response_model = None
        return_annotation = sig.return_annotation
        if return_annotation != inspect.Signature.empty:
            if hasattr(return_annotation, "model_json_schema"):
                response_model = return_annotation
            elif get_origin(return_annotation) is Annotated:
                # Handle Annotated return types
                args = get_args(return_annotation)
                if args:
                    # Check if the first argument is a Pydantic model
                    if hasattr(args[0], "model_json_schema"):
                        response_model = args[0]
                    # Check if the first argument is a union type
                    # NOTE(review): the identity trick `get_origin(x) is type(x)`
                    # only matches PEP 604 unions (`X | Y`, both sides are
                    # types.UnionType); typing.Union[...] annotations are NOT
                    # detected by it — confirm this is acceptable.
                    elif get_origin(args[0]) is type(args[0]):  # Union type
                        union_args = get_args(args[0])
                        for arg in union_args:
                            if hasattr(arg, "model_json_schema"):
                                response_model = arg
                                break
            elif get_origin(return_annotation) is type(return_annotation):  # Union type
                # Handle union types - try to find the first Pydantic model
                # (same PEP 604-only caveat as above)
                args = get_args(return_annotation)
                for arg in args:
                    if hasattr(arg, "model_json_schema"):
                        response_model = arg
                        break
        return request_model, response_model, query_parameters
    except Exception:
        # If we can't analyze the function signature, return None
        return None, None, []
def _make_type_safe_for_fastapi(type_hint) -> type:
"""
Make a type hint safe for FastAPI by converting problematic types to their base types.
This handles cases like Literal["24h"] that cause forward reference errors.
Also removes Union with None to avoid anyOf with type: 'null' schemas.
"""
# Handle Literal types that might cause issues
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Literal:
args = get_args(type_hint)
if args:
# Get the type of the first literal value
first_arg = args[0]
if isinstance(first_arg, str):
return str
elif isinstance(first_arg, int):
return int
elif isinstance(first_arg, float):
return float
elif isinstance(first_arg, bool):
return bool
else:
return type(first_arg)
# Handle Union types (Python 3.10+ uses | syntax)
origin = get_origin(type_hint)
if origin is type(None) or (origin is type and type_hint is type(None)):
# This is just None, return None
return type_hint
# Handle Union types (both old Union and new | syntax)
if origin is type(type_hint) or (hasattr(type_hint, "__args__") and type_hint.__args__):
# This is a union type, find the non-None type
args = get_args(type_hint)
non_none_types = [arg for arg in args if arg is not type(None) and arg is not None]
if non_none_types:
# Return the first non-None type to avoid anyOf with null
return non_none_types[0]
elif args:
# If all args are None, return the first one
return args[0]
else:
return type_hint
# Not a union type, return as-is
return type_hint
def _generate_schema_for_type(type_hint) -> dict[str, Any]:
"""
Generate a JSON schema for a given type hint.
This is a simplified version that handles basic types.
"""
# Handle Union types (e.g., str | None)
if get_origin(type_hint) is type(None) or (get_origin(type_hint) is type and type_hint is type(None)):
return {"type": "null"}
# Handle list types
if get_origin(type_hint) is list:
args = get_args(type_hint)
if args:
item_type = args[0]
return {"type": "array", "items": _generate_schema_for_type(item_type)}
return {"type": "array"}
# Handle basic types
if type_hint is str:
return {"type": "string"}
elif type_hint is int:
return {"type": "integer"}
elif type_hint is float:
return {"type": "number"}
elif type_hint is bool:
return {"type": "boolean"}
elif type_hint is dict:
return {"type": "object"}
elif type_hint is list:
return {"type": "array"}
# For complex types, try to get the schema from Pydantic
try:
if hasattr(type_hint, "model_json_schema"):
return type_hint.model_json_schema()
elif hasattr(type_hint, "__name__"):
return {"$ref": f"#/components/schemas/{type_hint.__name__}"}
except Exception:
pass
# Fallback
return {"type": "object"}
def _add_llama_stack_extensions(openapi_schema: dict[str, Any], app: FastAPI) -> dict[str, Any]:
    """
    Add Llama Stack specific extensions to the OpenAPI schema.

    This attaches ``x-llama-stack-extra-body-params`` to every operation
    whose implementing function declares ExtraBodyField parameters.
    """
    # Walk the discovered routes and match them against the generated paths.
    for api_name, routes in get_all_api_routes().items():
        for route, webmethod in routes:
            path_item = openapi_schema.get("paths", {}).get(route.path, {})
            for http_method in route.methods:
                operation = path_item.get(http_method.lower())
                if operation is None:
                    continue
                # Resolve the implementing function and pull its
                # ExtraBodyField parameters (may be empty).
                extra_body_params = _find_extra_body_params_for_route(api_name, route, webmethod)
                if extra_body_params:
                    operation["x-llama-stack-extra-body-params"] = extra_body_params
    return openapi_schema
def _find_extra_body_params_for_route(api_name: str, route, webmethod) -> list[dict[str, Any]]:
"""
Find the actual function that implements a route and extract its ExtraBodyField parameters.
"""
try:
# Try to get the actual function from the API protocol map
from llama_stack.core.resolver import api_protocol_map
# Look up the API implementation
if api_name in api_protocol_map:
_ = api_protocol_map[api_name]
# Try to find the method that matches this route
# This is a simplified approach - we'd need to map the route to the actual method
# For now, we'll return an empty list to avoid hardcoding
return []
return []
except Exception:
# If we can't find the function, return empty list
return []
def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure all @json_schema_type decorated models are included in the OpenAPI schema.
    This finds all models with the _llama_stack_schema_type attribute and adds them to the schema.
    """
    if "components" not in openapi_schema:
        openapi_schema["components"] = {}
    if "schemas" not in openapi_schema["components"]:
        openapi_schema["components"]["schemas"] = {}
    # Find all classes with the _llama_stack_schema_type attribute
    from llama_stack import apis

    # Get all modules in the apis package
    apis_modules = []
    for module_name in dir(apis):
        if not module_name.startswith("_"):
            try:
                module = getattr(apis, module_name)
                # Having __file__ is used as the "is a module" heuristic.
                if hasattr(module, "__file__"):
                    apis_modules.append(module)
            except (ImportError, AttributeError):
                continue
    # Also check submodules
    # NOTE(review): appending to apis_modules while iterating it means newly
    # appended entries are themselves scanned as well (a transitive walk).
    # That may be intended, but it can revisit modules and grow the list
    # substantially — confirm.
    for module in apis_modules:
        for attr_name in dir(module):
            if not attr_name.startswith("_"):
                try:
                    attr = getattr(module, attr_name)
                    if hasattr(attr, "__file__") and hasattr(attr, "__name__"):
                        apis_modules.append(attr)
                except (ImportError, AttributeError):
                    continue
    # Find all classes with the _llama_stack_schema_type attribute
    # (model_json_schema marks them as Pydantic models we can serialize)
    for module in apis_modules:
        for attr_name in dir(module):
            try:
                attr = getattr(module, attr_name)
                if (
                    hasattr(attr, "_llama_stack_schema_type")
                    and hasattr(attr, "model_json_schema")
                    and hasattr(attr, "__name__")
                ):
                    schema_name = attr.__name__
                    # First definition of a name wins; later duplicates are kept out.
                    if schema_name not in openapi_schema["components"]["schemas"]:
                        try:
                            schema = attr.model_json_schema()
                            openapi_schema["components"]["schemas"][schema_name] = schema
                        except Exception:
                            # Skip if we can't generate the schema
                            continue
            except (AttributeError, TypeError):
                continue
    # Also include any dynamic models that were created during endpoint generation
    # This is a workaround to ensure dynamic models appear in the schema
    global _dynamic_models
    if "_dynamic_models" in globals():
        for model in _dynamic_models:
            try:
                schema_name = model.__name__
                if schema_name not in openapi_schema["components"]["schemas"]:
                    schema = model.model_json_schema()
                    openapi_schema["components"]["schemas"][schema_name] = schema
            except Exception:
                # Skip if we can't generate the schema
                continue
    return openapi_schema
def _fix_ref_references(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Fix $ref references to point to components/schemas instead of $defs.
This prevents the YAML dumper from creating a root-level $defs section.
"""
def fix_refs(obj: Any) -> None:
if isinstance(obj, dict):
if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
# Replace #/$defs/ with #/components/schemas/
obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/")
for value in obj.values():
fix_refs(value)
elif isinstance(obj, list):
for item in obj:
fix_refs(item)
fix_refs(openapi_schema)
return openapi_schema
def _fix_anyof_with_null(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Fix anyOf schemas that contain type: 'null' by removing the null type
and making the field optional through the required field instead.
"""
def fix_anyof(obj: Any) -> None:
if isinstance(obj, dict):
if "anyOf" in obj and isinstance(obj["anyOf"], list):
# Check if anyOf contains type: 'null'
has_null = any(item.get("type") == "null" for item in obj["anyOf"] if isinstance(item, dict))
if has_null:
# Remove null types and keep only the non-null types
non_null_types = [
item for item in obj["anyOf"] if not (isinstance(item, dict) and item.get("type") == "null")
]
if len(non_null_types) == 1:
# If only one non-null type remains, replace anyOf with that type
obj.update(non_null_types[0])
if "anyOf" in obj:
del obj["anyOf"]
else:
# Keep the anyOf but without null types
obj["anyOf"] = non_null_types
# Recursively process all values
for value in obj.values():
fix_anyof(value)
elif isinstance(obj, list):
for item in obj:
fix_anyof(item)
fix_anyof(openapi_schema)
return openapi_schema
def _eliminate_defs_section(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Eliminate $defs section entirely by moving all definitions to components/schemas.
This matches the structure of the old pyopenapi generator for oasdiff compatibility.
"""
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# First pass: collect all $defs from anywhere in the schema
defs_to_move = {}
def collect_defs(obj: Any) -> None:
if isinstance(obj, dict):
if "$defs" in obj:
# Collect $defs for later processing
for def_name, def_schema in obj["$defs"].items():
if def_name not in defs_to_move:
defs_to_move[def_name] = def_schema
# Recursively process all values
for value in obj.values():
collect_defs(value)
elif isinstance(obj, list):
for item in obj:
collect_defs(item)
# Collect all $defs
collect_defs(openapi_schema)
# Move all $defs to components/schemas
for def_name, def_schema in defs_to_move.items():
if def_name not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"][def_name] = def_schema
# Also move any existing root-level $defs to components/schemas
if "$defs" in openapi_schema:
print(f"Found root-level $defs with {len(openapi_schema['$defs'])} items, moving to components/schemas")
for def_name, def_schema in openapi_schema["$defs"].items():
if def_name not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"][def_name] = def_schema
# Remove the root-level $defs
del openapi_schema["$defs"]
# Second pass: remove all $defs sections from anywhere in the schema
def remove_defs(obj: Any) -> None:
if isinstance(obj, dict):
if "$defs" in obj:
del obj["$defs"]
# Recursively process all values
for value in obj.values():
remove_defs(value)
elif isinstance(obj, list):
for item in obj:
remove_defs(item)
# Remove all $defs sections
remove_defs(openapi_schema)
return openapi_schema
def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Add standard error response definitions to the OpenAPI schema.
Uses the actual Error model from the codebase for consistency.
"""
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "responses" not in openapi_schema["components"]:
openapi_schema["components"]["responses"] = {}
# Import the actual Error model
try:
from llama_stack.apis.datatypes import Error
# Generate the Error schema using Pydantic
error_schema = Error.model_json_schema()
# Ensure the Error schema is in the components/schemas
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# Only add Error schema if it doesn't already exist
if "Error" not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"]["Error"] = error_schema
except ImportError:
# Fallback if we can't import the Error model
error_schema = {"$ref": "#/components/schemas/Error"}
# Define standard HTTP error responses
error_responses = {
400: {
"name": "BadRequest400",
"description": "The request was invalid or malformed",
"example": {"status": 400, "title": "Bad Request", "detail": "The request was invalid or malformed"},
},
429: {
"name": "TooManyRequests429",
"description": "The client has sent too many requests in a given amount of time",
"example": {
"status": 429,
"title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.",
},
},
500: {
"name": "InternalServerError500",
"description": "The server encountered an unexpected error",
"example": {
"status": 500,
"title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.",
},
},
}
# Add each error response to the schema
for _, error_info in error_responses.items():
response_name = error_info["name"]
openapi_schema["components"]["responses"][response_name] = {
"description": error_info["description"],
"content": {
"application/json": {"schema": {"$ref": "#/components/schemas/Error"}, "example": error_info["example"]}
},
}
# Add a default error response
openapi_schema["components"]["responses"]["DefaultError"] = {
"description": "An unexpected error occurred",
"content": {"application/json": {"schema": {"$ref": "#/components/schemas/Error"}}},
}
return openapi_schema
def _fix_path_parameters(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure every operation declares its path parameters explicitly.

    Descriptions come from the implementing function's docstring when the
    (path, method) pair was registered in _path_webmethod_map. Non-path
    parameters keep their position; path parameters are re-ordered to follow
    the docstring order.
    """
    global _path_webmethod_map
    if "paths" not in openapi_schema:
        return openapi_schema

    for path, path_item in openapi_schema["paths"].items():
        for http_method in ["get", "post", "put", "delete", "patch", "head", "options"]:
            operation = path_item.get(http_method)
            if not isinstance(operation, dict):
                continue
            # Look up the backing webmethod for docstring descriptions.
            webmethod = _path_webmethod_map.get((path, http_method))
            expected_params = _extract_path_parameters(path, webmethod)
            if not expected_params:
                continue
            if "parameters" not in operation:
                operation["parameters"] = []
            current = operation["parameters"]
            # Non-path parameters first, then path parameters in docstring order.
            rebuilt = [p for p in current if p.get("in") != "path"]
            existing_by_name = {p.get("name"): p for p in current if p.get("in") == "path"}
            for expected in expected_params:
                param_name = expected["name"]
                existing = existing_by_name.get(param_name)
                if existing is not None:
                    # Prefer a docstring-derived description over the placeholder.
                    if expected["description"] != f"Path parameter: {param_name}":
                        existing["description"] = expected["description"]
                    rebuilt.append(existing)
                else:
                    rebuilt.append(expected)
            operation["parameters"] = rebuilt
    return openapi_schema
def _extract_first_line_from_description(description: str) -> str:
"""
Extract all lines from a description string that don't start with docstring keywords.
Stops at the first line that starts with :param:, :returns:, :raises:, etc.
Preserves multiple lines and formatting.
"""
if not description:
return description
lines = description.split("\n")
description_lines = []
for line in lines:
stripped = line.strip()
if not stripped:
# Keep empty lines in the description to preserve formatting
description_lines.append(line)
continue
if (
stripped.startswith(":param")
or stripped.startswith(":returns")
or stripped.startswith(":raises")
or (stripped.startswith(":") and len(stripped) > 1 and stripped[1].isalpha())
):
break
description_lines.append(line)
# Join lines and strip trailing whitespace/newlines
result = "\n".join(description_lines).rstrip()
return result if result else description
def _fix_component_descriptions(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Trim every component schema description down to its summary text,
    stripping trailing ``:param:``/``:returns:``/``:raises:`` directives
    (see _extract_first_line_from_description). Mutates the schema in place
    and returns it.
    """
    if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]:
        return openapi_schema

    def _walk(node: Any) -> None:
        # Depth-first traversal over nested dicts/lists, rewriting every
        # string-valued "description" field encountered.
        if isinstance(node, dict):
            desc = node.get("description")
            if isinstance(desc, str):
                node["description"] = _extract_first_line_from_description(desc)
            for child in node.values():
                _walk(child)
        elif isinstance(node, list):
            for child in node:
                _walk(child)

    for schema_def in openapi_schema["components"]["schemas"].values():
        _walk(schema_def)
    return openapi_schema
def _fix_schema_issues(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Apply common repairs that fix OpenAPI validation problems: numeric
    ``exclusiveMinimum`` bounds become ``minimum`` values, and ``default: null``
    entries are removed (see the two helpers). Mutates in place and returns
    the schema.
    """
    components = openapi_schema.get("components")
    if not components or "schemas" not in components:
        return openapi_schema
    for schema_def in components["schemas"].values():
        _fix_exclusive_minimum_in_schema(schema_def)
        _fix_all_null_defaults(schema_def)
    return openapi_schema
def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI schema") -> bool:
    """
    Validate an OpenAPI schema with openapi-spec-validator.

    Args:
        schema: The OpenAPI schema dictionary to validate
        schema_name: Label used in the status messages printed to stdout

    Returns:
        True when the spec validates cleanly, False otherwise. Validation
        errors are printed, never raised.
    """
    try:
        validate_spec(schema)
    except OpenAPISpecValidatorError as e:
        print(f"{schema_name} validation failed:")
        print(f"  {e}")
        return False
    except Exception as e:
        print(f"{schema_name} validation error: {e}")
        return False
    print(f"{schema_name} is valid")
    return True
def validate_schema_file(file_path: Path) -> bool:
    """
    Load an OpenAPI schema file (YAML or JSON) and validate it.

    Args:
        file_path: Path to the schema file

    Returns:
        True if the file loads and validates, False otherwise (including
        unsupported extensions and read/parse failures, which are printed).
    """
    suffix = file_path.suffix.lower()
    try:
        with open(file_path) as fh:
            if suffix in (".yaml", ".yml"):
                schema = yaml.safe_load(fh)
            elif suffix == ".json":
                schema = json.load(fh)
            else:
                print(f"❌ Unsupported file format: {file_path.suffix}")
                return False
    except Exception as e:
        print(f"❌ Failed to read {file_path}: {e}")
        return False
    return validate_openapi_schema(schema, str(file_path))
def _fix_exclusive_minimum_in_schema(obj: Any) -> None:
"""
Recursively fix exclusiveMinimum issues in a schema object.
Converts exclusiveMinimum numbers to minimum values.
"""
if isinstance(obj, dict):
# Check if this is a schema with exclusiveMinimum
if "exclusiveMinimum" in obj and isinstance(obj["exclusiveMinimum"], int | float):
# Convert exclusiveMinimum number to minimum
obj["minimum"] = obj["exclusiveMinimum"]
del obj["exclusiveMinimum"]
# Recursively process all values
for value in obj.values():
_fix_exclusive_minimum_in_schema(value)
elif isinstance(obj, list):
# Recursively process all items
for item in obj:
_fix_exclusive_minimum_in_schema(item)
def _fix_string_fields_with_null_defaults(obj: Any) -> None:
"""
Recursively fix string fields that have default: null.
This violates OpenAPI spec - string fields should either have a string default or be optional.
"""
if isinstance(obj, dict):
# Check if this is a field definition with type: string and default: null
if obj.get("type") == "string" and "default" in obj and obj["default"] is None:
# Remove the default: null to make the field optional
del obj["default"]
# Add nullable: true to indicate the field can be null
obj["nullable"] = True
# Recursively process all values
for value in obj.values():
_fix_string_fields_with_null_defaults(value)
elif isinstance(obj, list):
# Recursively process all items
for item in obj:
_fix_string_fields_with_null_defaults(item)
def _fix_anyof_with_null_defaults(obj: Any) -> None:
"""
Recursively fix anyOf schemas that have default: null.
This violates OpenAPI spec - anyOf fields should not have null defaults.
"""
if isinstance(obj, dict):
# Check if this is a field definition with anyOf and default: null
if "anyOf" in obj and "default" in obj and obj["default"] is None:
# Remove the default: null to make the field optional
del obj["default"]
# Add nullable: true to indicate the field can be null
obj["nullable"] = True
# Recursively process all values
for value in obj.values():
_fix_anyof_with_null_defaults(value)
elif isinstance(obj, list):
# Recursively process all items
for item in obj:
_fix_anyof_with_null_defaults(item)
def _fix_all_null_defaults(obj: Any) -> None:
"""
Recursively fix all field types that have default: null.
This violates OpenAPI spec - fields should not have null defaults.
"""
if isinstance(obj, dict):
# Check if this is a field definition with default: null
if "default" in obj and obj["default"] is None:
# Remove the default: null to make the field optional
del obj["default"]
# Add nullable: true to indicate the field can be null
obj["nullable"] = True
# Recursively process all values
for value in obj.values():
_fix_all_null_defaults(value)
elif isinstance(obj, list):
# Recursively process all items
for item in obj:
_fix_all_null_defaults(item)
def _sort_paths_alphabetically(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Sort the paths in the OpenAPI schema by version prefix first, then alphabetically.
Also sort HTTP methods alphabetically within each path.
Version order: v1beta, v1alpha, v1
"""
if "paths" not in openapi_schema:
return openapi_schema
def path_sort_key(path: str) -> tuple:
"""
Create a sort key that groups paths by version prefix first.
Returns (version_priority, path) where version_priority:
- 0 for v1beta
- 1 for v1alpha
- 2 for v1
- 3 for others
"""
if path.startswith("/v1beta/"):
version_priority = 0
elif path.startswith("/v1alpha/"):
version_priority = 1
elif path.startswith("/v1/"):
version_priority = 2
else:
version_priority = 3
return (version_priority, path)
def sort_path_item(path_item: dict[str, Any]) -> dict[str, Any]:
"""Sort HTTP methods alphabetically within a path item."""
if not isinstance(path_item, dict):
return path_item
# Define the order of HTTP methods
method_order = ["delete", "get", "head", "options", "patch", "post", "put", "trace"]
# Create a new ordered dict with methods in alphabetical order
sorted_path_item = {}
# First add methods in the defined order
for method in method_order:
if method in path_item:
sorted_path_item[method] = path_item[method]
# Then add any other keys that aren't HTTP methods
for key, value in path_item.items():
if key not in method_order:
sorted_path_item[key] = value
return sorted_path_item
# Sort paths by version prefix first, then alphabetically
# Also sort HTTP methods within each path
sorted_paths = {}
for path, path_item in sorted(openapi_schema["paths"].items(), key=lambda x: path_sort_key(x[0])):
sorted_paths[path] = sort_path_item(path_item)
openapi_schema["paths"] = sorted_paths
return openapi_schema
def _filter_schema_by_version(
    openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True
) -> dict[str, Any]:
    """
    Filter an OpenAPI schema by API maturity level.

    Args:
        openapi_schema: The full OpenAPI schema
        stable_only: If True, return only /v1/ paths (stable). If False, return
            only /v1alpha/ and /v1beta/ paths (experimental).
        exclude_deprecated: If True, exclude deprecated endpoints from the result.

    Returns:
        A filtered copy of the schema. The input schema is not mutated.
    """
    filtered_schema = openapi_schema.copy()
    if "paths" not in filtered_schema:
        return filtered_schema

    # Filter paths based on version prefix and deprecated status.
    filtered_paths = {}
    for path, path_item in filtered_schema["paths"].items():
        if exclude_deprecated and _is_path_deprecated(path_item):
            continue
        is_stable = path.startswith("/v1/") and not path.startswith("/v1alpha/") and not path.startswith("/v1beta/")
        is_experimental = path.startswith("/v1alpha/") or path.startswith("/v1beta/")
        if (stable_only and is_stable) or (not stable_only and is_experimental):
            filtered_paths[path] = path_item
    filtered_schema["paths"] = filtered_paths

    # Keep only the component schemas actually referenced by the kept paths.
    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
        # BUGFIX: the top-level .copy() above is shallow, so "components" is
        # shared with the input schema. Copy it before replacing "schemas",
        # otherwise the caller's schema would be mutated.
        filtered_schema["components"] = dict(filtered_schema["components"])
        # Resolve references against the original (unfiltered) schema.
        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
        filtered_schema["components"]["schemas"] = {
            name: schema_def
            for name, schema_def in openapi_schema["components"]["schemas"].items()
            if name in referenced_schemas
        }

    # Preserve the $defs section verbatim if the source schema carried one.
    if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
        if "components" not in filtered_schema:
            filtered_schema["components"] = {}
        filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
        print(f"Preserved $defs section with {len(openapi_schema['components']['$defs'])} items")
    else:
        print("No $defs section to preserve")

    return filtered_schema
def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]:
    """
    Collect the names of all component schemas reachable from the given paths.

    Walks every operation of every path for ``$ref`` entries, adds refs found
    in the shared ``components/responses`` section, then expands the set
    transitively so that schemas referenced by referenced schemas are included.
    """
    referenced: set[str] = set()

    # Direct references from the kept operations.
    for path_item in filtered_paths.values():
        if not isinstance(path_item, dict):
            continue
        for method in ("get", "post", "put", "delete", "patch", "head", "options"):
            operation = path_item.get(method)
            if isinstance(operation, dict):
                referenced |= _find_schema_refs_in_object(operation)

    # Shared response components may carry refs of their own.
    components = openapi_schema.get("components", {})
    if "responses" in components:
        referenced |= _find_schema_refs_in_object(components["responses"])

    # Transitive closure: follow schema-to-schema references until stable.
    all_schemas = components.get("schemas", {})
    worklist = set(referenced)
    while worklist:
        name = worklist.pop()
        if name not in all_schemas:
            continue
        for dep in _find_schema_refs_in_object(all_schemas[name]):
            if dep not in referenced:
                referenced.add(dep)
                worklist.add(dep)

    return referenced
def _find_schema_refs_in_object(obj: Any) -> set[str]:
"""
Recursively find all schema references ($ref) in an object.
"""
refs = set()
if isinstance(obj, dict):
for key, value in obj.items():
if key == "$ref" and isinstance(value, str) and value.startswith("#/components/schemas/"):
schema_name = value.split("/")[-1]
refs.add(schema_name)
else:
refs.update(_find_schema_refs_in_object(value))
elif isinstance(obj, list):
for item in obj:
refs.update(_find_schema_refs_in_object(item))
return refs
def _is_path_deprecated(path_item: dict[str, Any]) -> bool:
"""
Check if a path item has any deprecated operations.
"""
if not isinstance(path_item, dict):
return False
# Check each HTTP method in the path item
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
if method in path_item:
operation = path_item[method]
if isinstance(operation, dict) and operation.get("deprecated", False):
return True
return False
def _filter_deprecated_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Build a spec containing only deprecated endpoints, regardless of version
    prefix (v1, v1alpha, v1beta). Returns a shallow copy with paths replaced.
    """
    filtered_schema = openapi_schema.copy()
    if "paths" not in filtered_schema:
        return filtered_schema
    filtered_schema["paths"] = {
        path: item for path, item in filtered_schema["paths"].items() if _is_path_deprecated(item)
    }
    return filtered_schema
def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Filter an OpenAPI schema to include both stable (v1) and experimental
    (v1alpha, v1beta) APIs, excluding deprecated endpoints. This is used for
    the combined "stainless" spec.

    Returns:
        A filtered copy of the schema. The input schema is not mutated.
    """
    filtered_schema = openapi_schema.copy()
    if "paths" not in filtered_schema:
        return filtered_schema

    # Keep stable and experimental paths; drop anything deprecated.
    filtered_paths = {}
    for path, path_item in filtered_schema["paths"].items():
        if _is_path_deprecated(path_item):
            continue
        is_stable = path.startswith("/v1/") and not path.startswith("/v1alpha/") and not path.startswith("/v1beta/")
        is_experimental = path.startswith("/v1alpha/") or path.startswith("/v1beta/")
        if is_stable or is_experimental:
            filtered_paths[path] = path_item
    filtered_schema["paths"] = filtered_paths

    # Keep only the component schemas referenced by the kept paths.
    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
        # BUGFIX: the top-level .copy() above is shallow, so "components" is
        # shared with the input schema. Copy it before replacing "schemas",
        # otherwise the caller's schema would be mutated.
        filtered_schema["components"] = dict(filtered_schema["components"])
        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
        filtered_schema["components"]["schemas"] = {
            name: schema_def
            for name, schema_def in openapi_schema["components"]["schemas"].items()
            if name in referenced_schemas
        }
    return filtered_schema
def generate_openapi_spec(output_dir: str, format: str = "yaml", include_examples: bool = True) -> dict[str, Any]:
    """
    Generate OpenAPI specification using FastAPI's built-in method.

    Builds the full schema, applies the post-processing fix-ups defined in this
    module, splits the result into four variants (stable, experimental,
    deprecated, combined/"stainless"), validates each, and writes YAML/JSON/HTML
    artifacts into *output_dir*.

    Args:
        output_dir: Directory to save the generated files
        format: Output format ("yaml", "json", or "both")
        include_examples: Whether to include examples in the spec

    Returns:
        The generated OpenAPI specification as a dictionary (the stable variant)
    """
    # Create the FastAPI app
    app = create_llama_stack_app()
    # Generate the OpenAPI schema
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        servers=app.servers,
    )
    # Debug: Check if there's a root-level $defs in the original schema
    if "$defs" in openapi_schema:
        print(f"Original schema has root-level $defs with {len(openapi_schema['$defs'])} items")
    else:
        print("Original schema has no root-level $defs")
    # Add Llama Stack specific extensions
    openapi_schema = _add_llama_stack_extensions(openapi_schema, app)
    # Add standard error responses
    openapi_schema = _add_error_responses(openapi_schema)
    # Ensure all @json_schema_type decorated models are included
    openapi_schema = _ensure_json_schema_types_included(openapi_schema)
    # Fix $ref references to point to components/schemas instead of $defs
    openapi_schema = _fix_ref_references(openapi_schema)
    # Debug: Check if there are any $ref references to $defs in the schema
    defs_refs = []

    def find_defs_refs(obj: Any, path: str = "") -> None:
        # Recursive walker that records the JSON-path of every "#/$defs/" ref.
        if isinstance(obj, dict):
            if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
                defs_refs.append(f"{path}: {obj['$ref']}")
            for key, value in obj.items():
                find_defs_refs(value, f"{path}.{key}" if path else key)
        elif isinstance(obj, list):
            for i, item in enumerate(obj):
                find_defs_refs(item, f"{path}[{i}]")

    find_defs_refs(openapi_schema)
    if defs_refs:
        print(f"Found {len(defs_refs)} $ref references to $defs in schema")
        for ref in defs_refs[:5]:  # Show first 5
            print(f"  {ref}")
    else:
        print("No $ref references to $defs found in schema")
    # Note: Let Pydantic/FastAPI generate the correct, standards-compliant schema
    # Fields with default values should be optional according to OpenAPI standards
    # Fix anyOf schemas with type: 'null' to avoid oasdiff errors
    openapi_schema = _fix_anyof_with_null(openapi_schema)
    # Fix path parameter resolution issues
    openapi_schema = _fix_path_parameters(openapi_schema)
    # Eliminate $defs section entirely for oasdiff compatibility
    openapi_schema = _eliminate_defs_section(openapi_schema)
    # Fix component descriptions to only include first line (summary)
    openapi_schema = _fix_component_descriptions(openapi_schema)
    # Debug: Check if there's a root-level $defs after flattening
    if "$defs" in openapi_schema:
        print(f"After flattening: root-level $defs with {len(openapi_schema['$defs'])} items")
    else:
        print("After flattening: no root-level $defs")
    # Ensure all referenced schemas are included
    # DISABLED: This was using hardcoded schema generation. FastAPI should handle this automatically.
    # openapi_schema = _ensure_referenced_schemas(openapi_schema)
    # Control schema registration based on @json_schema_type decorator
    # Temporarily disabled to fix missing schema issues
    # openapi_schema = _control_schema_registration(openapi_schema)
    # Fix malformed schemas after all other processing
    # DISABLED: This was a hardcoded workaround. Using Pydantic's TypeAdapter instead.
    # _fix_malformed_schemas(openapi_schema)
    # Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
    # Each spec needs its own deep copy of the full schema to avoid cross-contamination
    import copy

    stable_schema = _filter_schema_by_version(copy.deepcopy(openapi_schema), stable_only=True, exclude_deprecated=True)
    experimental_schema = _filter_schema_by_version(
        copy.deepcopy(openapi_schema), stable_only=False, exclude_deprecated=True
    )
    deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema))
    combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema))
    # Update title and description for combined schema
    if "info" in combined_schema:
        combined_schema["info"]["title"] = "Llama Stack API - Stable & Experimental APIs"
        combined_schema["info"]["description"] = (
            combined_schema["info"].get("description", "")
            + "\n\n**🔗 COMBINED**: This specification includes both stable production-ready APIs and experimental pre-release APIs. "
            "Use stable APIs for production deployments and experimental APIs for testing new features."
        )
    # Sort paths alphabetically for stable (v1 only)
    stable_schema = _sort_paths_alphabetically(stable_schema)
    # Sort paths by version prefix for experimental (v1beta, v1alpha)
    experimental_schema = _sort_paths_alphabetically(experimental_schema)
    # Sort paths by version prefix for deprecated
    deprecated_schema = _sort_paths_alphabetically(deprecated_schema)
    # Sort paths by version prefix for combined (stainless)
    combined_schema = _sort_paths_alphabetically(combined_schema)
    # Fix schema issues (like exclusiveMinimum -> minimum) for each spec
    stable_schema = _fix_schema_issues(stable_schema)
    experimental_schema = _fix_schema_issues(experimental_schema)
    deprecated_schema = _fix_schema_issues(deprecated_schema)
    combined_schema = _fix_schema_issues(combined_schema)
    # Validate the schemas
    print("\n🔍 Validating generated schemas...")
    stable_valid = validate_openapi_schema(stable_schema, "Stable schema")
    experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema")
    deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema")
    combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema")
    if not all([stable_valid, experimental_valid, deprecated_valid, combined_valid]):
        # Validation failures are non-fatal: specs are still written out.
        print("⚠️ Some schemas failed validation, but continuing with generation...")
    # Add any custom modifications here if needed
    if include_examples:
        # Add examples to the schema if needed
        pass
    # Ensure output directory exists
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Save the stable specification
    if format in ["yaml", "both"]:
        yaml_path = output_path / "llama-stack-spec.yaml"
        # Use ruamel.yaml for better YAML formatting
        try:
            from ruamel.yaml import YAML

            yaml_writer = YAML()
            yaml_writer.default_flow_style = False
            yaml_writer.sort_keys = False
            yaml_writer.width = 4096  # Prevent line wrapping
            yaml_writer.allow_unicode = True
            with open(yaml_path, "w") as f:
                yaml_writer.dump(stable_schema, f)
            # Post-process the YAML file to remove $defs section and fix references
            # Re-read and re-write with ruamel.yaml
            with open(yaml_path) as f:
                yaml_content = f.read()
            if " $defs:" in yaml_content or "#/$defs/" in yaml_content:
                print("Post-processing YAML to remove $defs section")
                # Use string replacement to fix references directly
                if "#/$defs/" in yaml_content:
                    refs_fixed = yaml_content.count("#/$defs/")
                    yaml_content = yaml_content.replace("#/$defs/", "#/components/schemas/")
                    print(f"Fixed {refs_fixed} $ref references using string replacement")
                # NOTE(review): the replacement above only changes the in-memory
                # yaml_content; safe_load below re-reads the on-disk file, so the
                # "#/$defs/" ref rewrite is never persisted — verify intended.
                # Parse using PyYAML safe_load first to avoid issues with custom types
                # This handles block scalars better during post-processing
                import yaml as pyyaml

                with open(yaml_path) as f:
                    yaml_data = pyyaml.safe_load(f)
                # Move $defs to components/schemas if it exists
                if "$defs" in yaml_data:
                    print(f"Found $defs section with {len(yaml_data['$defs'])} items")
                    if "components" not in yaml_data:
                        yaml_data["components"] = {}
                    if "schemas" not in yaml_data["components"]:
                        yaml_data["components"]["schemas"] = {}
                    # Move all $defs to components/schemas
                    for def_name, def_schema in yaml_data["$defs"].items():
                        yaml_data["components"]["schemas"][def_name] = def_schema
                    # Remove the $defs section
                    del yaml_data["$defs"]
                    print("Moved $defs to components/schemas")
                # Write the modified YAML back with ruamel.yaml
                with open(yaml_path, "w") as f:
                    yaml_writer.dump(yaml_data, f)
                print("Updated YAML file")
        except ImportError:
            # Fallback to standard yaml if ruamel.yaml is not available
            with open(yaml_path, "w") as f:
                yaml.dump(stable_schema, f, default_flow_style=False, sort_keys=False)
        print(f"✅ Generated YAML (stable): {yaml_path}")
        experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
        with open(experimental_yaml_path, "w") as f:
            yaml.dump(experimental_schema, f, default_flow_style=False, sort_keys=False)
        print(f"✅ Generated YAML (experimental): {experimental_yaml_path}")
        deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
        with open(deprecated_yaml_path, "w") as f:
            yaml.dump(deprecated_schema, f, default_flow_style=False, sort_keys=False)
        print(f"✅ Generated YAML (deprecated): {deprecated_yaml_path}")
        # Generate combined (stainless) spec
        stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
        try:
            from ruamel.yaml import YAML

            yaml_writer = YAML()
            yaml_writer.default_flow_style = False
            yaml_writer.sort_keys = False
            yaml_writer.width = 4096  # Prevent line wrapping
            yaml_writer.allow_unicode = True
            with open(stainless_yaml_path, "w") as f:
                yaml_writer.dump(combined_schema, f)
        except ImportError:
            # Fallback to standard yaml if ruamel.yaml is not available
            with open(stainless_yaml_path, "w") as f:
                yaml.dump(combined_schema, f, default_flow_style=False, sort_keys=False)
        print(f"✅ Generated YAML (stainless/combined): {stainless_yaml_path}")
    if format in ["json", "both"]:
        json_path = output_path / "llama-stack-spec.json"
        with open(json_path, "w") as f:
            json.dump(stable_schema, f, indent=2)
        print(f"✅ Generated JSON (stable): {json_path}")
        experimental_json_path = output_path / "experimental-llama-stack-spec.json"
        with open(experimental_json_path, "w") as f:
            json.dump(experimental_schema, f, indent=2)
        print(f"✅ Generated JSON (experimental): {experimental_json_path}")
        deprecated_json_path = output_path / "deprecated-llama-stack-spec.json"
        with open(deprecated_json_path, "w") as f:
            json.dump(deprecated_schema, f, indent=2)
        print(f"✅ Generated JSON (deprecated): {deprecated_json_path}")
        stainless_json_path = output_path / "stainless-llama-stack-spec.json"
        with open(stainless_json_path, "w") as f:
            json.dump(combined_schema, f, indent=2)
        print(f"✅ Generated JSON (stainless/combined): {stainless_json_path}")
    # Generate HTML documentation (always, regardless of the format flag)
    html_path = output_path / "llama-stack-spec.html"
    generate_html_docs(stable_schema, html_path)
    print(f"✅ Generated HTML: {html_path}")
    experimental_html_path = output_path / "experimental-llama-stack-spec.html"
    generate_html_docs(experimental_schema, experimental_html_path, spec_file="experimental-llama-stack-spec.yaml")
    print(f"✅ Generated HTML (experimental): {experimental_html_path}")
    deprecated_html_path = output_path / "deprecated-llama-stack-spec.html"
    generate_html_docs(deprecated_schema, deprecated_html_path, spec_file="deprecated-llama-stack-spec.yaml")
    print(f"✅ Generated HTML (deprecated): {deprecated_html_path}")
    stainless_html_path = output_path / "stainless-llama-stack-spec.html"
    generate_html_docs(combined_schema, stainless_html_path, spec_file="stainless-llama-stack-spec.yaml")
    print(f"✅ Generated HTML (stainless/combined): {stainless_html_path}")
    return stable_schema
def generate_html_docs(
    openapi_schema: dict[str, Any], output_path: Path, spec_file: str = "llama-stack-spec.yaml"
) -> None:
    """Write a ReDoc HTML page to *output_path* that renders *spec_file*.

    The *openapi_schema* argument is accepted for interface symmetry but is
    not embedded in the page; ReDoc fetches the spec from the sibling file
    named by *spec_file* at view time.
    """
    page = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Llama Stack API Documentation</title>
    <meta charset="utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
    <style>
        body {{ margin: 0; padding: 0; }}
    </style>
</head>
<body>
    <redoc spec-url='{spec_file}'></redoc>
    <script src="https://cdn.jsdelivr.net/npm/redoc@2.0.0/bundles/redoc.standalone.js"></script>
</body>
</html>
""".strip()
    output_path.write_text(page + "\n")
def main():
    """Main entry point for the FastAPI OpenAPI generator.

    Parses CLI arguments and either validates existing schema files
    (``--validate-only`` / ``--validate-file``, returning exit code 0/1) or
    generates fresh specs into the output directory (returns None; errors
    are re-raised after being printed).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate OpenAPI specification using FastAPI")
    parser.add_argument("output_dir", help="Output directory for generated files")
    parser.add_argument("--format", choices=["yaml", "json", "both"], default="both", help="Output format")
    parser.add_argument("--no-examples", action="store_true", help="Exclude examples from the specification")
    parser.add_argument(
        "--validate-only", action="store_true", help="Only validate existing schema files, don't generate new ones"
    )
    parser.add_argument("--validate-file", help="Validate a specific schema file")
    args = parser.parse_args()
    # Handle validation-only mode
    if args.validate_only or args.validate_file:
        if args.validate_file:
            # Validate a specific file
            file_path = Path(args.validate_file)
            if not file_path.exists():
                print(f"❌ File not found: {file_path}")
                return 1
            print(f"🔍 Validating {file_path}...")
            is_valid = validate_schema_file(file_path)
            return 0 if is_valid else 1
        else:
            # Validate all schema files in output directory
            output_path = Path(args.output_dir)
            if not output_path.exists():
                print(f"❌ Output directory not found: {output_path}")
                return 1
            print(f"🔍 Validating all schema files in {output_path}...")
            schema_files = (
                list(output_path.glob("*.yaml")) + list(output_path.glob("*.yml")) + list(output_path.glob("*.json"))
            )
            if not schema_files:
                print("❌ No schema files found to validate")
                return 1
            all_valid = True
            for schema_file in schema_files:
                print(f"\n📄 Validating {schema_file.name}...")
                is_valid = validate_schema_file(schema_file)
                if not is_valid:
                    all_valid = False
            if all_valid:
                print("\n✅ All schema files are valid!")
                return 0
            else:
                print("\n❌ Some schema files failed validation")
                return 1
    # Generation mode
    print("🚀 Generating OpenAPI specification using FastAPI...")
    print(f"📁 Output directory: {args.output_dir}")
    print(f"📄 Format: {args.format}")
    try:
        openapi_schema = generate_openapi_spec(
            output_dir=args.output_dir, format=args.format, include_examples=not args.no_examples
        )
        print("\n✅ OpenAPI specification generated successfully!")
        print(f"📊 Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
        print(f"🛣️ Paths: {len(openapi_schema.get('paths', {}))}")
        # Count operations
        operation_count = 0
        for path_info in openapi_schema.get("paths", {}).values():
            for method in ["get", "post", "put", "delete", "patch"]:
                if method in path_info:
                    operation_count += 1
        print(f"🔧 Operations: {operation_count}")
    except Exception as e:
        print(f"❌ Error generating OpenAPI specification: {e}")
        raise
if __name__ == "__main__":
    # main() returns an exit status (0/1) in the validation modes and None for
    # a normal generation run; raise SystemExit so the shell actually sees
    # that status instead of always exiting 0.
    raise SystemExit(main())