Removes the need for the strong_typing and pyopenapi packages and uses
Pydantic exclusively for schema generation.
Our generator now relies purely on Pydantic and FastAPI. It is available
at `scripts/fastapi_generator.py`; you can run it like so:
```
uv run ./scripts/run_openapi_generator.sh
```
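The core idea is that FastAPI derives JSON Schema directly from Pydantic
models, so no custom schema machinery is needed. A minimal sketch of the
approach (the endpoint and model names here are illustrative, not the
generator itself):
```
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic import BaseModel

class HealthInfo(BaseModel):
    status: str

app = FastAPI(title="Llama Stack API", version="1.0.0")

@app.get("/v1/health")
async def health() -> HealthInfo:
    return HealthInfo(status="OK")

# The same call the generator builds on: the schema comes straight from Pydantic
spec = get_openapi(title=app.title, version=app.version, routes=app.routes)
print(spec["components"]["schemas"]["HealthInfo"])
```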
The generator will:
* Generate the deprecated, experimental, stable, and combined specs
* Validate all the specs it generates against OpenAPI standards
A few changes in the schema required some updates for oasdiff, so I've
added the following ignore rules. The new Pydantic-based generator is
likely more correct and follows OpenAPI standards better than the old
pyopenapi generator. Instead of trying to make the new generator match
the old one's quirks, we should focus on what's actually correct
according to OpenAPI standards.
These are non-critical changes:
* response-property-became-nullable: Backward compatible:
existing non-null values still work, and null is now also accepted
* response-required-property-removed: oasdiff reports a false
positive because it doesn't resolve $refs inside anyOf; we could use a
tool like 'redocly' to flatten the schema into a single file (see the
Python sketch after this list for what flattening amounts to).
* response-property-type-changed: properties are still object
types, but oasdiff doesn't resolve $refs, so it flags the missing
inline `type: object` even though the referenced schemas define
`type: object`
* request-property-one-of-removed: These are false positives
caused by schema restructuring (wrapping in anyOf for nullability,
using -Input variants, or simplifying nested oneOf structures)
that don't change the actual API contract - the same data types are
still accepted, just represented differently in the schema.
* request-parameter-enum-value-removed: These are false
positives caused by oasdiff not resolving $refs - the enum values
(asc, desc, assistants, batch) are still present in the referenced
schemas (Order and OpenAIFilePurpose), just represented via schema
references instead of inline enums.
* request-property-enum-value-removed: this is a false positive caused
by oasdiff not resolving $refs - the enum values (llm, embedding,
rerank) are still present in the referenced ModelType schema,
just represented via schema reference instead of inline enums.
* request-property-type-changed: These are schema quality issues
where type information is missing (due to Any fallback in dynamic
model creation), but the API contract remains unchanged -
properties still exist with correct names and defaults, so the same
requests will work.
* response-body-type-changed: These are false positives caused
by schema representation changes (from inferred/empty types to
explicit $ref schemas, or vice versa) - the actual response types
and API contract remain unchanged, just how they're represented in the
OpenAPI spec.
* response-media-type-removed: This is a false positive caused
by FastAPI's OpenAPI generator not documenting union return types with
AsyncIterator - the streaming functionality with text/event-stream
media type still works when stream=True is passed, it's just not
reflected in the generated OpenAPI spec.
* request-body-type-changed: This is a schema correction - the
old spec incorrectly represented the request body as an object, but
the function signature shows chunks: list[Chunk], so the new spec
correctly shows it as an array, matching the actual API
implementation.
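Most of the $ref-related false positives above come down to oasdiff
comparing references textually instead of resolving them. "Flattening"
just means inlining those references; a minimal sketch of the idea in
Python (illustrative only - a real tool such as redocly also handles
cyclic and remote references):
```
def inline_refs(node, schemas):
    # Replace {"$ref": "#/components/schemas/X"} with the body of schema X,
    # recursively, so two specs can be compared structurally.
    if isinstance(node, dict):
        ref = node.get("$ref", "")
        if isinstance(ref, str) and ref.startswith("#/components/schemas/"):
            return inline_refs(schemas[ref.rsplit("/", 1)[-1]], schemas)
        return {key: inline_refs(value, schemas) for key, value in node.items()}
    if isinstance(node, list):
        return [inline_refs(item, schemas) for item in node]
    return node
```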
Signed-off-by: Sébastien Han <seb@redhat.com>
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
FastAPI-based OpenAPI generator for Llama Stack.
"""

import importlib
import inspect
import json
import pkgutil
from pathlib import Path
from typing import Annotated, Any, get_args, get_origin

import yaml
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from openapi_spec_validator import validate_spec
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError

from llama_stack.apis.datatypes import Api
from llama_stack.apis.version import (
    LLAMA_STACK_API_V1,
    LLAMA_STACK_API_V1ALPHA,
    LLAMA_STACK_API_V1BETA,
)
from llama_stack.core.resolver import api_protocol_map

# Global list to store dynamic models created during endpoint generation
_dynamic_models = []


# Cache for protocol methods to avoid repeated lookups
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None


def create_llama_stack_app() -> FastAPI:
    """
    Create a FastAPI app that represents the Llama Stack API.
    This uses the existing route discovery system to automatically find all routes.
    """
    app = FastAPI(
        title="Llama Stack API",
        description="A comprehensive API for building and deploying AI applications",
        version="1.0.0",
        servers=[
            {"url": "http://any-hosted-llama-stack.com"},
        ],
    )

    # Get all API routes
    from llama_stack.core.server.routes import get_all_api_routes

    api_routes = get_all_api_routes()

    # Create FastAPI routes from the discovered routes
    for api, routes in api_routes.items():
        for route, webmethod in routes:
            # Convert the route to a FastAPI endpoint
            _create_fastapi_endpoint(app, route, webmethod, api)

    return app


def _get_protocol_method(api: Api, method_name: str) -> Any | None:
    """
    Get a protocol method function by API and method name.
    Uses caching to avoid repeated lookups.

    Args:
        api: The API enum
        method_name: The method name (function name)

    Returns:
        The function object, or None if not found
    """
    global _protocol_methods_cache

    if _protocol_methods_cache is None:
        _protocol_methods_cache = {}
        protocols = api_protocol_map()
        from llama_stack.apis.tools import SpecialToolGroup, ToolRuntime

        toolgroup_protocols = {
            SpecialToolGroup.rag_tool: ToolRuntime,
        }

        for api_key, protocol in protocols.items():
            method_map: dict[str, Any] = {}
            protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
            for name, method in protocol_methods:
                method_map[name] = method

            # Handle tool_runtime special case
            if api_key == Api.tool_runtime:
                for tool_group, sub_protocol in toolgroup_protocols.items():
                    sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                    for name, method in sub_protocol_methods:
                        if hasattr(method, "__webmethod__"):
                            method_map[f"{tool_group.value}.{name}"] = method

            _protocol_methods_cache[api_key] = method_map

    return _protocol_methods_cache.get(api, {}).get(method_name)


def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
    """Extract path parameters from a URL path and return them as OpenAPI parameter definitions."""
    import re

    matches = re.findall(r"\{([^}:]+)(?::[^}]+)?\}", path)
    return [
        {
            "name": param_name,
            "in": "path",
            "required": True,
            "schema": {"type": "string"},
            "description": f"Path parameter: {param_name}",
        }
        for param_name in matches
    ]
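
# For illustration (not part of the generator): given the path
# "/v1/files/{file_id}/content", the regex above yields one required string
# parameter named "file_id"; typed placeholders like "{name:path}" also match,
# with the ":path" suffix discarded by the non-capturing group.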


def _create_endpoint_with_request_model(
    request_model: type, response_model: type | None, operation_description: str | None
):
    """Create an endpoint function with a request body model."""

    async def endpoint(request: request_model) -> response_model:
        return response_model() if response_model else {}

    if operation_description:
        endpoint.__doc__ = operation_description
    return endpoint


def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_any: bool = False) -> dict[str, tuple]:
    """Build field definitions for a Pydantic model from query parameters."""
    from typing import Any

    from pydantic.fields import FieldInfo

    field_definitions = {}
    for param_name, param_type, default_value in query_parameters:
        if use_any:
            field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
            continue

        base_type = param_type
        extracted_field = None
        if get_origin(param_type) is Annotated:
            args = get_args(param_type)
            if args:
                base_type = args[0]
                for arg in args[1:]:
                    # Annotated metadata produced by Field(...) is a FieldInfo instance;
                    # pydantic.Field itself is a function, so it can't be used with isinstance()
                    if isinstance(arg, FieldInfo):
                        extracted_field = arg
                        break

        try:
            if extracted_field:
                field_definitions[param_name] = (base_type, extracted_field)
            else:
                field_definitions[param_name] = (
                    base_type,
                    ... if default_value is inspect.Parameter.empty else default_value,
                )
        except (TypeError, ValueError):
            field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)

    # Ensure all parameters are included
    expected_params = {name for name, _, _ in query_parameters}
    missing = expected_params - set(field_definitions.keys())
    if missing:
        for param_name, _, default_value in query_parameters:
            if param_name in missing:
                field_definitions[param_name] = (
                    Any,
                    ... if default_value is inspect.Parameter.empty else default_value,
                )

    return field_definitions


def _create_dynamic_request_model(
    webmethod, query_parameters: list[tuple[str, type, Any]], use_any: bool = False, add_uuid: bool = False
) -> type | None:
    """Create a dynamic Pydantic model for request body."""
    import uuid

    from pydantic import create_model

    try:
        field_definitions = _build_field_definitions(query_parameters, use_any)
        clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
        model_name = f"{clean_route}_Request"
        if add_uuid:
            model_name = f"{model_name}_{uuid.uuid4().hex[:8]}"

        request_model = create_model(model_name, **field_definitions)
        _dynamic_models.append(request_model)
        return request_model
    except Exception:
        return None
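
# For illustration: a webmethod routed at "/v1/models/{model_id}" with
# query_parameters [("metadata", dict | None, None)] produces roughly
#     create_model("_v1_models_model_id_Request", metadata=(dict | None, None))
# i.e. a throwaway Pydantic model whose only job is to render the request-body schema.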


def _build_signature_params(
    query_parameters: list[tuple[str, type, Any]],
) -> tuple[list[inspect.Parameter], dict[str, type]]:
    """Build signature parameters and annotations from query parameters."""
    signature_params = []
    param_annotations = {}
    for param_name, param_type, default_value in query_parameters:
        param_annotations[param_name] = param_type
        signature_params.append(
            inspect.Parameter(
                param_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=default_value if default_value is not inspect.Parameter.empty else inspect.Parameter.empty,
                annotation=param_type,
            )
        )
    return signature_params, param_annotations


def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
    """Create a FastAPI endpoint from a discovered route and webmethod."""
    path = route.path
    methods = route.methods
    name = route.name
    # Discovered route paths already use FastAPI-style {param} placeholders
    fastapi_path = path

    request_model, response_model, query_parameters, file_form_params = _find_models_for_endpoint(webmethod, api, name)
    operation_description = _extract_operation_description_from_docstring(api, name)
    response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
    is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)

    if file_form_params and is_post_put:
        signature_params = list(file_form_params)
        param_annotations = {param.name: param.annotation for param in file_form_params}
        for param_name, param_type, default_value in query_parameters:
            signature_params.append(
                inspect.Parameter(
                    param_name,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    default=default_value if default_value is not inspect.Parameter.empty else inspect.Parameter.empty,
                    annotation=param_type,
                )
            )
            param_annotations[param_name] = param_type

        async def file_form_endpoint():
            return response_model() if response_model else {}

        if operation_description:
            file_form_endpoint.__doc__ = operation_description
        file_form_endpoint.__signature__ = inspect.Signature(signature_params)
        file_form_endpoint.__annotations__ = param_annotations
        endpoint_func = file_form_endpoint
    elif request_model and response_model:
        endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
    elif response_model and query_parameters:
        if is_post_put:
            # Try creating request model with type preservation, fallback to Any, then minimal
            request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=False)
            if not request_model:
                request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True)
            if not request_model:
                request_model = _create_dynamic_request_model(webmethod, query_parameters, use_any=True, add_uuid=True)

            if request_model:
                endpoint_func = _create_endpoint_with_request_model(
                    request_model, response_model, operation_description
                )
            else:

                async def empty_endpoint() -> response_model:
                    return response_model() if response_model else {}

                if operation_description:
                    empty_endpoint.__doc__ = operation_description
                endpoint_func = empty_endpoint
        else:
            sorted_params = sorted(query_parameters, key=lambda x: (x[2] is not inspect.Parameter.empty, x[0]))
            signature_params, param_annotations = _build_signature_params(sorted_params)

            async def query_endpoint():
                return response_model()

            if operation_description:
                query_endpoint.__doc__ = operation_description
            query_endpoint.__signature__ = inspect.Signature(signature_params)
            query_endpoint.__annotations__ = param_annotations
            endpoint_func = query_endpoint
    elif response_model:

        async def response_only_endpoint() -> response_model:
            return response_model()

        if operation_description:
            response_only_endpoint.__doc__ = operation_description
        endpoint_func = response_only_endpoint
    elif query_parameters:
        signature_params, param_annotations = _build_signature_params(query_parameters)

        async def params_only_endpoint():
            return {}

        if operation_description:
            params_only_endpoint.__doc__ = operation_description
        params_only_endpoint.__signature__ = inspect.Signature(signature_params)
        params_only_endpoint.__annotations__ = param_annotations
        endpoint_func = params_only_endpoint
    else:

        async def no_params_endpoint():
            return {}

        if operation_description:
            no_params_endpoint.__doc__ = operation_description
        endpoint_func = no_params_endpoint

    # Add the endpoint to the FastAPI app
    is_deprecated = webmethod.deprecated or False
    route_kwargs = {
        "name": name,
        "tags": [_get_tag_from_api(api)],
        "deprecated": is_deprecated,
        "responses": {
            200: {
                "description": response_description,
                "content": {
                    "application/json": {
                        "schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {}
                    }
                },
            },
            400: {"$ref": "#/components/responses/BadRequest400"},
            429: {"$ref": "#/components/responses/TooManyRequests429"},
            500: {"$ref": "#/components/responses/InternalServerError500"},
            "default": {"$ref": "#/components/responses/DefaultError"},
        },
    }

    method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
    for method in methods:
        if handler := method_map.get(method.upper()):
            handler(fastapi_path, **route_kwargs)(endpoint_func)


def _extract_operation_description_from_docstring(api: Api, method_name: str) -> str | None:
    """Extract operation description from the actual function docstring."""
    func = _get_protocol_method(api, method_name)
    if not func or not func.__doc__:
        return None

    doc_lines = func.__doc__.split("\n")
    description_lines = []
    metadata_markers = (":param", ":type", ":return", ":returns", ":raises", ":exception", ":yield", ":yields", ":cvar")

    for line in doc_lines:
        if line.strip().startswith(metadata_markers):
            break
        description_lines.append(line)

    description = "\n".join(description_lines).strip()
    return description if description else None
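
# For illustration: given a protocol docstring such as
#     """List all models.
#
#     :param limit: Maximum number of models to return.
#     :returns: A ListModelsResponse.
#     """
# the operation description becomes just "List all models." - everything from
# the first Sphinx-style marker onward is dropped.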


def _extract_response_description_from_docstring(webmethod, response_model, api: Api, method_name: str) -> str:
    """Extract response description from the actual function docstring."""
    func = _get_protocol_method(api, method_name)
    if not func or not func.__doc__:
        return "Successful Response"
    for line in func.__doc__.split("\n"):
        if line.strip().startswith(":returns:"):
            if desc := line.strip()[9:].strip():
                return desc
    return "Successful Response"


def _get_tag_from_api(api: Api) -> str:
    """Extract a tag name from the API enum for API grouping."""
    return api.value.replace("_", " ").title()


def _is_file_or_form_param(param_type: Any) -> bool:
    """Check if a parameter type is annotated with File() or Form()."""
    if get_origin(param_type) is Annotated:
        args = get_args(param_type)
        if len(args) > 1:
            # Check metadata for File or Form
            for metadata in args[1:]:
                # Check if it's a File or Form instance
                if hasattr(metadata, "__class__"):
                    class_name = metadata.__class__.__name__
                    if class_name in ("File", "Form"):
                        return True
    return False


def _find_models_for_endpoint(
    webmethod, api: Api, method_name: str
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]:
    """
    Find appropriate request and response models for an endpoint by analyzing the actual function signature.
    This uses the protocol function to determine the correct models dynamically.

    Args:
        webmethod: The webmethod metadata
        api: The API enum for looking up the function
        method_name: The method name (function name)

    Returns:
        tuple: (request_model, response_model, query_parameters, file_form_params)
        where query_parameters is a list of (name, type, default_value) tuples
        and file_form_params is a list of inspect.Parameter objects for File()/Form() params
    """
    try:
        # Get the function from the protocol
        func = _get_protocol_method(api, method_name)
        if not func:
            return None, None, [], []

        # Analyze the function signature
        sig = inspect.signature(func)

        # Find request model and collect all body parameters
        request_model = None
        query_parameters = []
        file_form_params = []
        path_params = set()

        # Extract path parameters from the route
        if webmethod and hasattr(webmethod, "route"):
            import re

            path_matches = re.findall(r"\{([^}:]+)(?::[^}]+)?\}", webmethod.route)
            path_params = set(path_matches)

        for param_name, param in sig.parameters.items():
            if param_name == "self":
                continue

            # Skip *args and **kwargs parameters - these are not real API parameters
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue

            # Check if this is a path parameter
            if param_name in path_params:
                # Path parameters are handled separately, skip them
                continue

            # Check if it's a File() or Form() parameter - these need special handling
            param_type = param.annotation
            if _is_file_or_form_param(param_type):
                # File() and Form() parameters must be in the function signature directly
                # They cannot be part of a Pydantic model
                file_form_params.append(param)
                continue

            # Check if it's a Pydantic model (for POST/PUT requests)
            if hasattr(param_type, "model_json_schema"):
                # Collect all body parameters including Pydantic models
                # We'll decide later whether to use a single model or create a combined one
                query_parameters.append((param_name, param_type, param.default))
            elif get_origin(param_type) is Annotated:
                # Handle Annotated types - get the base type
                args = get_args(param_type)
                if args and hasattr(args[0], "model_json_schema"):
                    # Collect Pydantic models from Annotated types
                    query_parameters.append((param_name, args[0], param.default))
                else:
                    # Regular annotated parameter (but not File/Form, already handled above)
                    query_parameters.append((param_name, param_type, param.default))
            else:
                # This is likely a body parameter for POST/PUT or query parameter for GET
                # Store the parameter info for later use
                # Preserve inspect.Parameter.empty to distinguish "no default" from "default=None"
                default_value = param.default

                # Extract the base type from union types (e.g., str | None -> str)
                # Also make it safe for FastAPI to avoid forward reference issues
                query_parameters.append((param_name, param_type, default_value))

        # If there's exactly one body parameter and it's a Pydantic model, use it directly
        # Otherwise, we'll create a combined request model from all parameters
        if len(query_parameters) == 1:
            param_name, param_type, default_value = query_parameters[0]
            if hasattr(param_type, "model_json_schema"):
                request_model = param_type
                query_parameters = []  # Clear query_parameters so we use the single model

        # Find response model from return annotation
        response_model = None
        return_annotation = sig.return_annotation
        if return_annotation != inspect.Signature.empty:
            if hasattr(return_annotation, "model_json_schema"):
                response_model = return_annotation
            elif get_origin(return_annotation) is Annotated:
                # Handle Annotated return types
                args = get_args(return_annotation)
                if args:
                    # Check if the first argument is a Pydantic model
                    if hasattr(args[0], "model_json_schema"):
                        response_model = args[0]
                    # Check if the first argument is a union type
                    elif get_origin(args[0]) is type(args[0]):  # Union type
                        union_args = get_args(args[0])
                        for arg in union_args:
                            if hasattr(arg, "model_json_schema"):
                                response_model = arg
                                break
            elif get_origin(return_annotation) is type(return_annotation):  # Union type
                # Handle union types - try to find the first Pydantic model
                args = get_args(return_annotation)
                for arg in args:
                    if hasattr(arg, "model_json_schema"):
                        response_model = arg
                        break

        return request_model, response_model, query_parameters, file_form_params

    except Exception:
        # If we can't analyze the function signature, return None
        return None, None, [], []


def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
    """Ensure components.schemas exists in the schema."""
    if "components" not in openapi_schema:
        openapi_schema["components"] = {}
    if "schemas" not in openapi_schema["components"]:
        openapi_schema["components"]["schemas"] = {}


def _import_all_modules_in_package(package_name: str) -> list[Any]:
    """
    Dynamically import all modules in a package to trigger register_schema calls.

    This walks through all modules in the package and imports them, ensuring
    that any register_schema() calls at module level are executed.

    Args:
        package_name: The fully qualified package name (e.g., 'llama_stack.apis')

    Returns:
        List of imported module objects
    """
    modules = []
    try:
        package = importlib.import_module(package_name)
    except ImportError:
        return modules

    package_path = getattr(package, "__path__", None)
    if not package_path:
        return modules

    # Walk packages and modules recursively
    for _, modname, ispkg in pkgutil.walk_packages(package_path, prefix=f"{package_name}."):
        if not modname.startswith("_"):
            try:
                module = importlib.import_module(modname)
                modules.append(module)

                # If this is a package, also try to import any .py files directly
                # (e.g., llama_stack.apis.scoring_functions.scoring_functions)
                if ispkg:
                    try:
                        # Try importing the module file with the same name as the package
                        # This handles cases like scoring_functions/scoring_functions.py
                        module_file_name = f"{modname}.{modname.split('.')[-1]}"
                        module_file = importlib.import_module(module_file_name)
                        if module_file not in modules:
                            modules.append(module_file)
                    except (ImportError, AttributeError, TypeError):
                        # It's okay if this fails - not all packages have a module file with the same name
                        pass
            except (ImportError, AttributeError, TypeError):
                # Skip modules that can't be imported (e.g., missing dependencies)
                continue

    return modules


def _extract_and_fix_defs(schema: dict[str, Any], openapi_schema: dict[str, Any]) -> None:
    """
    Extract $defs from a schema, move them to components/schemas, and fix references.
    This handles both TypeAdapter-generated schemas and model_json_schema() schemas.
    """
    if "$defs" in schema:
        defs = schema.pop("$defs")
        for def_name, def_schema in defs.items():
            if def_name not in openapi_schema["components"]["schemas"]:
                openapi_schema["components"]["schemas"][def_name] = def_schema
            # Recursively handle $defs in nested schemas
            _extract_and_fix_defs(def_schema, openapi_schema)

    # Fix any references in the main schema that point to $defs
    def fix_refs_in_schema(obj: Any) -> None:
        if isinstance(obj, dict):
            if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
                obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/")
            for value in obj.values():
                fix_refs_in_schema(value)
        elif isinstance(obj, list):
            for item in obj:
                fix_refs_in_schema(item)

    fix_refs_in_schema(schema)
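
# For illustration: a schema like
#     {"$defs": {"Inner": {...}}, "properties": {"x": {"$ref": "#/$defs/Inner"}}}
# comes out as
#     {"properties": {"x": {"$ref": "#/components/schemas/Inner"}}}
# with "Inner" hoisted into the spec-wide components/schemas section.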


def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure all @json_schema_type decorated models and registered schemas are included in the OpenAPI schema.
    This finds all models with the _llama_stack_schema_type attribute and schemas registered via register_schema.
    """
    _ensure_components_schemas(openapi_schema)

    # Import TypeAdapter for handling union types and other non-model types
    from pydantic import TypeAdapter

    # Dynamically import all modules in packages that might register schemas
    # This ensures register_schema() calls execute and populate _registered_schemas
    # Also collect the modules for later scanning of @json_schema_type decorated classes
    apis_modules = _import_all_modules_in_package("llama_stack.apis")
    _import_all_modules_in_package("llama_stack.core.telemetry")

    # First, handle registered schemas (union types, etc.)
    from llama_stack.schema_utils import _registered_schemas

    for schema_type, registration_info in _registered_schemas.items():
        schema_name = registration_info["name"]
        if schema_name not in openapi_schema["components"]["schemas"]:
            try:
                # Use TypeAdapter for union types and other non-model types
                # Use ref_template to generate references in the format we need
                adapter = TypeAdapter(schema_type)
                schema = adapter.json_schema(ref_template="#/components/schemas/{model}")

                # Extract and fix $defs if present
                _extract_and_fix_defs(schema, openapi_schema)

                openapi_schema["components"]["schemas"][schema_name] = schema
            except Exception as e:
                # Skip if we can't generate the schema
                print(f"Warning: Failed to generate schema for registered type {schema_name}: {e}")
                import traceback

                traceback.print_exc()
                continue

    # Find all classes with the _llama_stack_schema_type attribute
    # Use the modules we already imported above
    for module in apis_modules:
        for attr_name in dir(module):
            try:
                attr = getattr(module, attr_name)
                if (
                    hasattr(attr, "_llama_stack_schema_type")
                    and hasattr(attr, "model_json_schema")
                    and hasattr(attr, "__name__")
                ):
                    schema_name = attr.__name__
                    if schema_name not in openapi_schema["components"]["schemas"]:
                        try:
                            # Use ref_template to ensure consistent reference format and $defs handling
                            schema = attr.model_json_schema(ref_template="#/components/schemas/{model}")
                            # Extract and fix $defs if present (model_json_schema can also generate $defs)
                            _extract_and_fix_defs(schema, openapi_schema)
                            openapi_schema["components"]["schemas"][schema_name] = schema
                        except Exception as e:
                            # Skip if we can't generate the schema
                            print(f"Warning: Failed to generate schema for {schema_name}: {e}")
                            continue
            except (AttributeError, TypeError):
                continue

    # Also include any dynamic models that were created during endpoint generation
    # This is a workaround to ensure dynamic models appear in the schema
    global _dynamic_models
    if "_dynamic_models" in globals():
        for model in _dynamic_models:
            try:
                schema_name = model.__name__
                if schema_name not in openapi_schema["components"]["schemas"]:
                    schema = model.model_json_schema(ref_template="#/components/schemas/{model}")
                    # Extract and fix $defs if present
                    _extract_and_fix_defs(schema, openapi_schema)
                    openapi_schema["components"]["schemas"][schema_name] = schema
            except Exception:
                # Skip if we can't generate the schema
                continue

    return openapi_schema
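
# For illustration: TypeAdapter lets non-BaseModel types (such as the union
# types passed to register_schema) emit JSON Schema too, e.g.
#     TypeAdapter(int | str).json_schema()
# returns {"anyOf": [{"type": "integer"}, {"type": "string"}]}.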


def _fix_ref_references(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Fix $ref references to point to components/schemas instead of $defs.
    This prevents the YAML dumper from creating a root-level $defs section.
    """

    def fix_refs(obj: Any) -> None:
        if isinstance(obj, dict):
            if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
                # Replace #/$defs/ with #/components/schemas/
                obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/")
            for value in obj.values():
                fix_refs(value)
        elif isinstance(obj, list):
            for item in obj:
                fix_refs(item)

    fix_refs(openapi_schema)
    return openapi_schema


def _eliminate_defs_section(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Eliminate $defs section entirely by moving all definitions to components/schemas.
    This matches the structure of the old pyopenapi generator for oasdiff compatibility.
    """
    _ensure_components_schemas(openapi_schema)

    # First pass: collect all $defs from anywhere in the schema
    defs_to_move = {}

    def collect_defs(obj: Any) -> None:
        if isinstance(obj, dict):
            if "$defs" in obj:
                # Collect $defs for later processing
                for def_name, def_schema in obj["$defs"].items():
                    if def_name not in defs_to_move:
                        defs_to_move[def_name] = def_schema

            # Recursively process all values
            for value in obj.values():
                collect_defs(value)
        elif isinstance(obj, list):
            for item in obj:
                collect_defs(item)

    # Collect all $defs
    collect_defs(openapi_schema)

    # Move all $defs to components/schemas
    for def_name, def_schema in defs_to_move.items():
        if def_name not in openapi_schema["components"]["schemas"]:
            openapi_schema["components"]["schemas"][def_name] = def_schema

    # Also move any existing root-level $defs to components/schemas
    if "$defs" in openapi_schema:
        print(f"Found root-level $defs with {len(openapi_schema['$defs'])} items, moving to components/schemas")
        for def_name, def_schema in openapi_schema["$defs"].items():
            if def_name not in openapi_schema["components"]["schemas"]:
                openapi_schema["components"]["schemas"][def_name] = def_schema
        # Remove the root-level $defs
        del openapi_schema["$defs"]

    # Second pass: remove all $defs sections from anywhere in the schema
    def remove_defs(obj: Any) -> None:
        if isinstance(obj, dict):
            if "$defs" in obj:
                del obj["$defs"]

            # Recursively process all values
            for value in obj.values():
                remove_defs(value)
        elif isinstance(obj, list):
            for item in obj:
                remove_defs(item)

    # Remove all $defs sections
    remove_defs(openapi_schema)

    return openapi_schema


def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Add standard error response definitions to the OpenAPI schema.
    Uses the actual Error model from the codebase for consistency.
    """
    if "components" not in openapi_schema:
        openapi_schema["components"] = {}
    if "responses" not in openapi_schema["components"]:
        openapi_schema["components"]["responses"] = {}

    try:
        from llama_stack.apis.datatypes import Error

        _ensure_components_schemas(openapi_schema)
        if "Error" not in openapi_schema["components"]["schemas"]:
            openapi_schema["components"]["schemas"]["Error"] = Error.model_json_schema()
    except ImportError:
        pass

    # Define standard HTTP error responses
    error_responses = {
        400: {
            "name": "BadRequest400",
            "description": "The request was invalid or malformed",
            "example": {"status": 400, "title": "Bad Request", "detail": "The request was invalid or malformed"},
        },
        429: {
            "name": "TooManyRequests429",
            "description": "The client has sent too many requests in a given amount of time",
            "example": {
                "status": 429,
                "title": "Too Many Requests",
                "detail": "You have exceeded the rate limit. Please try again later.",
            },
        },
        500: {
            "name": "InternalServerError500",
            "description": "The server encountered an unexpected error",
            "example": {"status": 500, "title": "Internal Server Error", "detail": "An unexpected error occurred"},
        },
    }

    # Add each error response to the schema
    for _, error_info in error_responses.items():
        response_name = error_info["name"]
        openapi_schema["components"]["responses"][response_name] = {
            "description": error_info["description"],
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/Error"}, "example": error_info["example"]}
            },
        }

    # Add a default error response
    openapi_schema["components"]["responses"]["DefaultError"] = {
        "description": "An error occurred",
        "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Error"}}},
    }

    return openapi_schema


def _fix_path_parameters(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Fix path parameter resolution issues by adding explicit parameter definitions.
    """
    if "paths" not in openapi_schema:
        return openapi_schema

    for path, path_item in openapi_schema["paths"].items():
        # Extract path parameters from the URL
        path_params = _extract_path_parameters(path)

        if not path_params:
            continue

        # Add parameters to each operation in this path
        for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
            if method in path_item and isinstance(path_item[method], dict):
                operation = path_item[method]
                if "parameters" not in operation:
                    operation["parameters"] = []

                # Add path parameters that aren't already defined
                existing_param_names = {p.get("name") for p in operation["parameters"] if p.get("in") == "path"}
                for param in path_params:
                    if param["name"] not in existing_param_names:
                        operation["parameters"].append(param)

    return openapi_schema


def _fix_schema_issues(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """Fix common schema issues: exclusiveMinimum and null defaults."""
    if "components" in openapi_schema and "schemas" in openapi_schema["components"]:
        for schema_def in openapi_schema["components"]["schemas"].values():
            _fix_schema_recursive(schema_def)
    return openapi_schema


def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI schema") -> bool:
    """
    Validate an OpenAPI schema using openapi-spec-validator.

    Args:
        schema: The OpenAPI schema dictionary to validate
        schema_name: Name of the schema for error reporting

    Returns:
        True if valid, False otherwise; validation errors are printed rather than raised
    """
    try:
        validate_spec(schema)
        print(f"✅ {schema_name} is valid")
        return True
    except OpenAPISpecValidatorError as e:
        print(f"❌ {schema_name} validation failed:")
        print(f" {e}")
        return False
    except Exception as e:
        print(f"❌ {schema_name} validation error: {e}")
        return False


def validate_schema_file(file_path: Path) -> bool:
    """
    Validate an OpenAPI schema file (YAML or JSON).

    Args:
        file_path: Path to the schema file

    Returns:
        True if valid, False otherwise
    """
    try:
        with open(file_path) as f:
            if file_path.suffix.lower() in [".yaml", ".yml"]:
                schema = yaml.safe_load(f)
            elif file_path.suffix.lower() == ".json":
                schema = json.load(f)
            else:
                print(f"❌ Unsupported file format: {file_path.suffix}")
                return False

        return validate_openapi_schema(schema, str(file_path))
    except Exception as e:
        print(f"❌ Failed to read {file_path}: {e}")
        return False


def _fix_schema_recursive(obj: Any) -> None:
    """Recursively fix schema issues: exclusiveMinimum and null defaults."""
    if isinstance(obj, dict):
        if "exclusiveMinimum" in obj and isinstance(obj["exclusiveMinimum"], int | float):
            obj["minimum"] = obj.pop("exclusiveMinimum")
        if "default" in obj and obj["default"] is None:
            del obj["default"]
            obj["nullable"] = True
        for value in obj.values():
            _fix_schema_recursive(value)
    elif isinstance(obj, list):
        for item in obj:
            _fix_schema_recursive(item)
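
# For illustration: {"type": "number", "exclusiveMinimum": 0} becomes
# {"type": "number", "minimum": 0}, and {"type": "string", "default": None}
# becomes {"type": "string", "nullable": True}.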


def _clean_description(description: str) -> str:
    """Remove :param, :type, :returns, and other docstring metadata from description."""
    if not description:
        return description

    lines = description.split("\n")
    cleaned_lines = []
    skip_until_empty = False

    for line in lines:
        stripped = line.strip()
        # Skip lines that start with docstring metadata markers
        if stripped.startswith(
            (":param", ":type", ":return", ":returns", ":raises", ":exception", ":yield", ":yields", ":cvar")
        ):
            skip_until_empty = True
            continue
        # If we're skipping and hit an empty line, resume normal processing
        if skip_until_empty:
            if not stripped:
                skip_until_empty = False
            continue
        # Include the line if we're not skipping
        cleaned_lines.append(line)

    # Join and strip trailing whitespace
    result = "\n".join(cleaned_lines).strip()
    return result
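
# For illustration: unlike the operation-description extractor above, which stops
# at the first metadata marker, this keeps prose that follows a metadata block:
#     "A model.\n:param id: The id.\n\nSee also: the registry."
# cleans up to
#     "A model.\nSee also: the registry."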


def _clean_schema_descriptions(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """Clean descriptions in schema definitions by removing docstring metadata."""
    if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]:
        return openapi_schema

    schemas = openapi_schema["components"]["schemas"]
    for schema_def in schemas.values():
        if isinstance(schema_def, dict) and "description" in schema_def and isinstance(schema_def["description"], str):
            schema_def["description"] = _clean_description(schema_def["description"])

    return openapi_schema


def _remove_query_params_from_body_endpoints(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Remove query parameters from POST/PUT/PATCH endpoints that have a request body.
    FastAPI sometimes infers parameters as query params even when they should be in the request body.
    """
    if "paths" not in openapi_schema:
        return openapi_schema

    body_methods = {"post", "put", "patch"}

    for _path, path_item in openapi_schema["paths"].items():
        if not isinstance(path_item, dict):
            continue

        for method in body_methods:
            if method not in path_item:
                continue

            operation = path_item[method]
            if not isinstance(operation, dict):
                continue

            # Check if this operation has a request body
            has_request_body = "requestBody" in operation and operation["requestBody"]

            if has_request_body:
                # Remove all query parameters (parameters with "in": "query")
                if "parameters" in operation:
                    # Filter out query parameters, keep path and header parameters
                    operation["parameters"] = [
                        param
                        for param in operation["parameters"]
                        if isinstance(param, dict) and param.get("in") != "query"
                    ]
                    # Remove the parameters key if it's now empty
                    if not operation["parameters"]:
                        del operation["parameters"]

    return openapi_schema


def _convert_multiline_strings_to_literal(obj: Any) -> Any:
    """Recursively convert multi-line strings to LiteralScalarString for YAML block scalar formatting."""
    try:
        from ruamel.yaml.scalarstring import LiteralScalarString

        if isinstance(obj, str) and "\n" in obj:
            return LiteralScalarString(obj)
        elif isinstance(obj, dict):
            return {key: _convert_multiline_strings_to_literal(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [_convert_multiline_strings_to_literal(item) for item in obj]
        else:
            return obj
    except ImportError:
        return obj


def _write_yaml_file(file_path: Path, schema: dict[str, Any]) -> None:
    """Write schema to YAML file using ruamel.yaml if available, otherwise standard yaml."""
    try:
        from ruamel.yaml import YAML

        yaml_writer = YAML()
        yaml_writer.default_flow_style = False
        yaml_writer.sort_keys = False
        yaml_writer.width = 4096
        yaml_writer.allow_unicode = True
        schema = _convert_multiline_strings_to_literal(schema)
        with open(file_path, "w") as f:
            yaml_writer.dump(schema, f)
    except ImportError:
        with open(file_path, "w") as f:
            yaml.dump(schema, f, default_flow_style=False, sort_keys=False)


def _filter_schema_by_version(
    openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True
) -> dict[str, Any]:
    """
    Filter OpenAPI schema by API version.

    Args:
        openapi_schema: The full OpenAPI schema
        stable_only: If True, return only /v1/ paths (stable). If False, return only /v1alpha/ and /v1beta/ paths (experimental).
        exclude_deprecated: If True, exclude deprecated endpoints from the result.

    Returns:
        Filtered OpenAPI schema
    """
    filtered_schema = openapi_schema.copy()

    if "paths" not in filtered_schema:
        return filtered_schema

    # Filter paths based on version prefix and deprecated status
    filtered_paths = {}
    for path, path_item in filtered_schema["paths"].items():
        # Check if path has any deprecated operations
        is_deprecated = _is_path_deprecated(path_item)

        # Skip deprecated endpoints if exclude_deprecated is True
        if exclude_deprecated and is_deprecated:
            continue

        if stable_only:
            # Only include stable v1 paths, exclude v1alpha and v1beta
            if _is_stable_path(path):
                filtered_paths[path] = path_item
        else:
            # Only include experimental paths (v1alpha or v1beta), exclude v1
            if _is_experimental_path(path):
                filtered_paths[path] = path_item

    filtered_schema["paths"] = filtered_paths

    # Filter schemas/components to only include ones referenced by filtered paths
    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
        # Find all schemas that are actually referenced by the filtered paths
        # Use the original schema to find all references, not the filtered one
        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)

        # Also include all registered schemas and @json_schema_type decorated models
        # (they should always be included) and all schemas they reference (transitive references)
        from llama_stack.schema_utils import _registered_schemas

        # Use the original schema to find registered schema definitions
        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
        registered_schema_names = set()
        for registration_info in _registered_schemas.values():
            registered_schema_names.add(registration_info["name"])

        # Also include all @json_schema_type decorated models
        json_schema_type_names = _get_all_json_schema_type_names()
        all_explicit_schema_names = registered_schema_names | json_schema_type_names

        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
        additional_schemas = set()
        for schema_name in all_explicit_schema_names:
            referenced_schemas.add(schema_name)
            if schema_name in all_schemas:
                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))

        # Keep adding transitive references until no new ones are found
        while additional_schemas:
            new_schemas = additional_schemas - referenced_schemas
            if not new_schemas:
                break
            referenced_schemas.update(new_schemas)
            additional_schemas = set()
            for schema_name in new_schemas:
                if schema_name in all_schemas:
                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))

        # Only keep schemas that are referenced by the filtered paths or are registered/@json_schema_type
        filtered_schemas = {}
        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
            if schema_name in referenced_schemas:
                filtered_schemas[schema_name] = schema_def

        filtered_schema["components"]["schemas"] = filtered_schemas

    # Preserve $defs section if it exists
    if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
        if "components" not in filtered_schema:
            filtered_schema["components"] = {}
        filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]

    return filtered_schema


def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]:
    """
    Find all schemas that are referenced by the filtered paths.
    This recursively traverses the path definitions to find all $ref references.
    """
    referenced_schemas = set()

    # Traverse all filtered paths
    for _, path_item in filtered_paths.items():
        if not isinstance(path_item, dict):
            continue

        # Check each HTTP method in the path
        for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
            if method in path_item:
                operation = path_item[method]
                if isinstance(operation, dict):
                    # Find all schema references in this operation
                    referenced_schemas.update(_find_schema_refs_in_object(operation))

    # Also check the responses section for schema references
    if "components" in openapi_schema and "responses" in openapi_schema["components"]:
        referenced_schemas.update(_find_schema_refs_in_object(openapi_schema["components"]["responses"]))

    # Also include schemas that are referenced by other schemas (transitive references)
    # This ensures we include all dependencies
    all_schemas = openapi_schema.get("components", {}).get("schemas", {})
    additional_schemas = set()

    for schema_name in referenced_schemas:
        if schema_name in all_schemas:
            additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))

    # Keep adding transitive references until no new ones are found
    while additional_schemas:
        new_schemas = additional_schemas - referenced_schemas
        if not new_schemas:
            break
        referenced_schemas.update(new_schemas)
        additional_schemas = set()
        for schema_name in new_schemas:
            if schema_name in all_schemas:
                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))

    return referenced_schemas


def _find_schema_refs_in_object(obj: Any) -> set[str]:
    """
    Recursively find all schema references ($ref) in an object.
    """
    refs = set()

    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == "$ref" and isinstance(value, str) and value.startswith("#/components/schemas/"):
                schema_name = value.split("/")[-1]
                refs.add(schema_name)
            else:
                refs.update(_find_schema_refs_in_object(value))
    elif isinstance(obj, list):
        for item in obj:
            refs.update(_find_schema_refs_in_object(item))

    return refs


def _get_all_json_schema_type_names() -> set[str]:
    """
    Get all schema names from @json_schema_type decorated models.
    This ensures they are included in filtered schemas even if not directly referenced by paths.
    """
    schema_names = set()
    apis_modules = _import_all_modules_in_package("llama_stack.apis")
    for module in apis_modules:
        for attr_name in dir(module):
            try:
                attr = getattr(module, attr_name)
                if (
                    hasattr(attr, "_llama_stack_schema_type")
                    and hasattr(attr, "model_json_schema")
                    and hasattr(attr, "__name__")
                ):
                    schema_names.add(attr.__name__)
            except (AttributeError, TypeError):
                continue
    return schema_names


def _is_path_deprecated(path_item: dict[str, Any]) -> bool:
    """Check if a path item has any deprecated operations."""
    if not isinstance(path_item, dict):
        return False
    for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
        if isinstance(path_item.get(method), dict) and path_item[method].get("deprecated", False):
            return True
    return False


def _path_starts_with_version(path: str, version: str) -> bool:
    """Check if a path starts with a specific API version prefix."""
    return path.startswith(f"/{version}/")


def _is_stable_path(path: str) -> bool:
    """Check if a path is a stable v1 path (not v1alpha or v1beta)."""
    return (
        _path_starts_with_version(path, LLAMA_STACK_API_V1)
        and not _path_starts_with_version(path, LLAMA_STACK_API_V1ALPHA)
        and not _path_starts_with_version(path, LLAMA_STACK_API_V1BETA)
    )


def _is_experimental_path(path: str) -> bool:
    """Check if a path is an experimental path (v1alpha or v1beta)."""
    return _path_starts_with_version(path, LLAMA_STACK_API_V1ALPHA) or _path_starts_with_version(
        path, LLAMA_STACK_API_V1BETA
    )
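
# For illustration, assuming the version constants are "v1", "v1alpha" and "v1beta":
#     _is_stable_path("/v1/models")          -> True
#     _is_experimental_path("/v1alpha/eval") -> True
#     _is_stable_path("/v1alpha/eval")       -> False ("/v1alpha/" does not start with "/v1/")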


def _filter_deprecated_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Filter OpenAPI schema to include only deprecated endpoints.
    Includes all deprecated endpoints regardless of version (v1, v1alpha, v1beta).
    """
    filtered_schema = openapi_schema.copy()

    if "paths" not in filtered_schema:
        return filtered_schema

    # Filter paths to only include deprecated ones
    filtered_paths = {}
    for path, path_item in filtered_schema["paths"].items():
        if _is_path_deprecated(path_item):
            filtered_paths[path] = path_item

    filtered_schema["paths"] = filtered_paths

    return filtered_schema


def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Filter OpenAPI schema to include both stable (v1) and experimental (v1alpha, v1beta) APIs.
    Excludes deprecated endpoints. This is used for the combined "stainless" spec.
    """
    filtered_schema = openapi_schema.copy()

    if "paths" not in filtered_schema:
        return filtered_schema

    # Filter paths to include stable (v1) and experimental (v1alpha, v1beta), excluding deprecated
    filtered_paths = {}
    for path, path_item in filtered_schema["paths"].items():
        # Check if path has any deprecated operations
        is_deprecated = _is_path_deprecated(path_item)

        # Skip deprecated endpoints
        if is_deprecated:
            continue

        # Include stable v1 paths
        if _is_stable_path(path):
            filtered_paths[path] = path_item
        # Include experimental paths (v1alpha or v1beta)
        elif _is_experimental_path(path):
            filtered_paths[path] = path_item

    filtered_schema["paths"] = filtered_paths

    # Filter schemas/components to only include ones referenced by filtered paths
    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)

        # Also include all registered schemas and @json_schema_type decorated models
        # (they should always be included) and all schemas they reference (transitive references)
        from llama_stack.schema_utils import _registered_schemas

        # Use the original schema to find registered schema definitions
        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
        registered_schema_names = set()
        for registration_info in _registered_schemas.values():
            registered_schema_names.add(registration_info["name"])

        # Also include all @json_schema_type decorated models
        json_schema_type_names = _get_all_json_schema_type_names()
        all_explicit_schema_names = registered_schema_names | json_schema_type_names

        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
        additional_schemas = set()
        for schema_name in all_explicit_schema_names:
            referenced_schemas.add(schema_name)
            if schema_name in all_schemas:
                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))

        # Keep adding transitive references until no new ones are found
        while additional_schemas:
            new_schemas = additional_schemas - referenced_schemas
            if not new_schemas:
                break
            referenced_schemas.update(new_schemas)
            additional_schemas = set()
            for schema_name in new_schemas:
                if schema_name in all_schemas:
                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))

        filtered_schemas = {}
        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
            if schema_name in referenced_schemas:
                filtered_schemas[schema_name] = schema_def

        filtered_schema["components"]["schemas"] = filtered_schemas

    return filtered_schema


def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
    """
    Generate OpenAPI specification using FastAPI's built-in method.

    Args:
        output_dir: Directory to save the generated files

    Returns:
        The generated OpenAPI specification as a dictionary
    """
    # Create the FastAPI app
    app = create_llama_stack_app()

    # Generate the OpenAPI schema
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        servers=app.servers,
    )

    # Set OpenAPI version to 3.1.0
    openapi_schema["openapi"] = "3.1.0"

    # Add standard error responses
    openapi_schema = _add_error_responses(openapi_schema)

    # Ensure all @json_schema_type decorated models are included
    openapi_schema = _ensure_json_schema_types_included(openapi_schema)

    # Fix $ref references to point to components/schemas instead of $defs
    openapi_schema = _fix_ref_references(openapi_schema)

    # Fix path parameter resolution issues
    openapi_schema = _fix_path_parameters(openapi_schema)

    # Eliminate $defs section entirely for oasdiff compatibility
    openapi_schema = _eliminate_defs_section(openapi_schema)

    # Clean descriptions in schema definitions by removing docstring metadata
    openapi_schema = _clean_schema_descriptions(openapi_schema)

    # Remove query parameters from POST/PUT/PATCH endpoints that have a request body
    # FastAPI sometimes infers parameters as query params even when they should be in the request body
    openapi_schema = _remove_query_params_from_body_endpoints(openapi_schema)

    # Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
    # Each spec needs its own deep copy of the full schema to avoid cross-contamination
    import copy

    stable_schema = _filter_schema_by_version(copy.deepcopy(openapi_schema), stable_only=True, exclude_deprecated=True)
    experimental_schema = _filter_schema_by_version(
        copy.deepcopy(openapi_schema), stable_only=False, exclude_deprecated=True
    )
    deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema))
    combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema))

    # Base description for all specs
    base_description = (
        "This is the specification of the Llama Stack that provides\n"
        " a set of endpoints and their corresponding interfaces that are\n"
        " tailored to\n"
        " best leverage Llama Models."
    )

    # Update info section for stable schema
    if "info" not in stable_schema:
        stable_schema["info"] = {}
    stable_schema["info"]["title"] = "Llama Stack Specification"
    stable_schema["info"]["version"] = "v1"
    stable_schema["info"]["description"] = (
        base_description + "\n\n **✅ STABLE**: Production-ready APIs with backward compatibility guarantees."
    )

    # Update info section for experimental schema
    if "info" not in experimental_schema:
        experimental_schema["info"] = {}
    experimental_schema["info"]["title"] = "Llama Stack Specification - Experimental APIs"
    experimental_schema["info"]["version"] = "v1"
    experimental_schema["info"]["description"] = (
        base_description + "\n\n **🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n"
        " becoming stable."
    )

    # Update info section for deprecated schema
    if "info" not in deprecated_schema:
        deprecated_schema["info"] = {}
    deprecated_schema["info"]["title"] = "Llama Stack Specification - Deprecated APIs"
    deprecated_schema["info"]["version"] = "v1"
    deprecated_schema["info"]["description"] = (
        base_description + "\n\n **⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n"
        " migration reference only."
    )

    # Update info section for combined schema
    if "info" not in combined_schema:
        combined_schema["info"] = {}
    combined_schema["info"]["title"] = "Llama Stack Specification - Stable & Experimental APIs"
    combined_schema["info"]["version"] = "v1"
    combined_schema["info"]["description"] = (
        base_description + "\n\n\n"
        " **🔗 COMBINED**: This specification includes both stable production-ready APIs\n"
        " and experimental pre-release APIs. Use stable APIs for production deployments\n"
        " and experimental APIs for testing new features."
    )

    # Fix schema issues (like exclusiveMinimum -> minimum) for each spec
    stable_schema = _fix_schema_issues(stable_schema)
    experimental_schema = _fix_schema_issues(experimental_schema)
    deprecated_schema = _fix_schema_issues(deprecated_schema)
    combined_schema = _fix_schema_issues(combined_schema)

    # Validate the schemas
    print("\n🔍 Validating generated schemas...")
    stable_valid = validate_openapi_schema(stable_schema, "Stable schema")
    experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema")
    deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema")
    combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema")

    if not all([stable_valid, experimental_valid, deprecated_valid, combined_valid]):
        print("⚠️ Some schemas failed validation, but continuing with generation...")

    # Ensure output directory exists
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save the stable specification
    yaml_path = output_path / "llama-stack-spec.yaml"
    _write_yaml_file(yaml_path, stable_schema)
    # Post-process the YAML file to remove $defs section and fix references
    with open(yaml_path) as f:
        yaml_content = f.read()

    if " $defs:" in yaml_content or "#/$defs/" in yaml_content:
        # Use string replacement to fix references directly
        if "#/$defs/" in yaml_content:
            yaml_content = yaml_content.replace("#/$defs/", "#/components/schemas/")

        # Parse the YAML content
        yaml_data = yaml.safe_load(yaml_content)

        # Move $defs to components/schemas if it exists
        if "$defs" in yaml_data:
            if "components" not in yaml_data:
                yaml_data["components"] = {}
            if "schemas" not in yaml_data["components"]:
                yaml_data["components"]["schemas"] = {}

            # Move all $defs to components/schemas
            for def_name, def_schema in yaml_data["$defs"].items():
                yaml_data["components"]["schemas"][def_name] = def_schema

            # Remove the $defs section
            del yaml_data["$defs"]

        # Write the modified YAML back
        _write_yaml_file(yaml_path, yaml_data)

    print(f"✅ Generated YAML (stable): {yaml_path}")

    experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
    _write_yaml_file(experimental_yaml_path, experimental_schema)
    print(f"✅ Generated YAML (experimental): {experimental_yaml_path}")

    deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
    _write_yaml_file(deprecated_yaml_path, deprecated_schema)
    print(f"✅ Generated YAML (deprecated): {deprecated_yaml_path}")

    # Generate combined (stainless) spec
    stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
    _write_yaml_file(stainless_yaml_path, combined_schema)
    print(f"✅ Generated YAML (stainless/combined): {stainless_yaml_path}")

    return stable_schema


def main():
    """Main entry point for the FastAPI OpenAPI generator."""
    import argparse

    parser = argparse.ArgumentParser(description="Generate OpenAPI specification using FastAPI")
    parser.add_argument("output_dir", help="Output directory for generated files")

    args = parser.parse_args()

    print("🚀 Generating OpenAPI specification using FastAPI...")
    print(f"📁 Output directory: {args.output_dir}")

    try:
        openapi_schema = generate_openapi_spec(output_dir=args.output_dir)

        print("\n✅ OpenAPI specification generated successfully!")
        print(f"📊 Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
        print(f"🛣️ Paths: {len(openapi_schema.get('paths', {}))}")
        operation_count = sum(
            1
            for path_info in openapi_schema.get("paths", {}).values()
            for method in ["get", "post", "put", "delete", "patch"]
            if method in path_info
        )
        print(f"🔧 Operations: {operation_count}")

    except Exception as e:
        print(f"❌ Error generating OpenAPI specification: {e}")
        raise


if __name__ == "__main__":
    main()
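
# Example invocation (the output directory below is illustrative):
#     uv run python scripts/fastapi_generator.py docs/static
# or via the wrapper script mentioned in the commit message:
#     uv run ./scripts/run_openapi_generator.sh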