mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
wip
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
ec702ac3fb
commit
38de8ea1f7
12 changed files with 26571 additions and 24377 deletions
|
|
@ -17,6 +17,8 @@ from typing import Annotated, Any, Literal, get_args, get_origin
|
|||
import yaml
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from openapi_spec_validator import validate_spec
|
||||
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
|
|
@ -24,6 +26,9 @@ from llama_stack.core.resolver import api_protocol_map
|
|||
# Import the existing route discovery system
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
||||
# Global list to store dynamic models created during endpoint generation
|
||||
_dynamic_models = []
|
||||
|
||||
|
||||
def _get_all_api_routes_with_functions():
|
||||
"""
|
||||
|
|
@ -108,6 +113,37 @@ def create_llama_stack_app() -> FastAPI:
|
|||
return app
|
||||
|
||||
|
||||
def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract path parameters from a URL path and return them as OpenAPI parameter definitions.
|
||||
|
||||
Args:
|
||||
path: URL path with parameters like /v1/batches/{batch_id}/cancel
|
||||
|
||||
Returns:
|
||||
List of parameter definitions for OpenAPI
|
||||
"""
|
||||
import re
|
||||
|
||||
# Find all path parameters in the format {param} or {param:type}
|
||||
param_pattern = r"\{([^}:]+)(?::[^}]+)?\}"
|
||||
matches = re.findall(param_pattern, path)
|
||||
|
||||
parameters = []
|
||||
for param_name in matches:
|
||||
parameters.append(
|
||||
{
|
||||
"name": param_name,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "string"},
|
||||
"description": f"Path parameter: {param_name}",
|
||||
}
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def _create_fastapi_endpoint(app: FastAPI, route, webmethod):
|
||||
"""
|
||||
Create a FastAPI endpoint from a discovered route and webmethod.
|
||||
|
|
@ -124,6 +160,12 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod):
|
|||
# Try to find actual models for this endpoint
|
||||
request_model, response_model, query_parameters = _find_models_for_endpoint(webmethod)
|
||||
|
||||
# Debug: Print info for safety endpoints
|
||||
if "safety" in webmethod.route or "shield" in webmethod.route:
|
||||
print(
|
||||
f"Debug: {webmethod.route} - request_model: {request_model}, response_model: {response_model}, query_parameters: {query_parameters}"
|
||||
)
|
||||
|
||||
# Extract response description from webmethod docstring (always try this first)
|
||||
response_description = _extract_response_description_from_docstring(webmethod, response_model)
|
||||
|
||||
|
|
@ -136,46 +178,107 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod):
|
|||
|
||||
endpoint_func = typed_endpoint
|
||||
elif response_model and query_parameters:
|
||||
# Request with individual parameters (could be GET with query params or POST with individual params)
|
||||
# Create a function with the actual query parameters
|
||||
def create_query_endpoint_func():
|
||||
# Build the function signature dynamically
|
||||
import inspect
|
||||
# Check if this is a POST/PUT endpoint with individual parameters
|
||||
# For POST/PUT, individual parameters should go in request body, not query params
|
||||
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
|
||||
|
||||
# Create parameter annotations
|
||||
param_annotations = {}
|
||||
param_defaults = {}
|
||||
if is_post_put:
|
||||
# POST/PUT with individual parameters - create a request body model
|
||||
try:
|
||||
from pydantic import create_model
|
||||
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
# Handle problematic type annotations that cause FastAPI issues
|
||||
safe_type = _make_type_safe_for_fastapi(param_type)
|
||||
param_annotations[param_name] = safe_type
|
||||
if default_value is not None:
|
||||
param_defaults[param_name] = default_value
|
||||
# Create a dynamic Pydantic model for the request body
|
||||
field_definitions = {}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
# Handle complex types that might cause issues with create_model
|
||||
safe_type = _make_type_safe_for_fastapi(param_type)
|
||||
|
||||
# Create the function signature
|
||||
sig = inspect.Signature(
|
||||
[
|
||||
inspect.Parameter(
|
||||
name=param_name,
|
||||
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
default=default_value,
|
||||
annotation=param_annotations[param_name],
|
||||
if default_value is None:
|
||||
field_definitions[param_name] = (safe_type, ...) # Required field
|
||||
else:
|
||||
field_definitions[param_name] = (safe_type, default_value) # Optional field with default
|
||||
|
||||
# Create the request model dynamically
|
||||
# Clean up the route name to create a valid schema name
|
||||
clean_route = webmethod.route.replace("/", "_").replace("{", "").replace("}", "").replace("-", "_")
|
||||
model_name = f"{clean_route}_Request"
|
||||
|
||||
print(f"Debug: Creating model {model_name} with fields: {field_definitions}")
|
||||
request_model = create_model(model_name, **field_definitions)
|
||||
print(f"Debug: Successfully created model {model_name}")
|
||||
|
||||
# Store the dynamic model in the global list for schema inclusion
|
||||
_dynamic_models.append(request_model)
|
||||
|
||||
# Create endpoint with request body
|
||||
async def typed_endpoint(request: request_model) -> response_model:
|
||||
"""Typed endpoint for proper schema generation."""
|
||||
return response_model()
|
||||
|
||||
# Set the function signature to ensure FastAPI recognizes the request model
|
||||
typed_endpoint.__annotations__ = {"request": request_model, "return": response_model}
|
||||
|
||||
endpoint_func = typed_endpoint
|
||||
except Exception as e:
|
||||
# If dynamic model creation fails, fall back to query parameters
|
||||
print(f"Warning: Failed to create dynamic request model for {webmethod.route}: {e}")
|
||||
print(f" Query parameters: {query_parameters}")
|
||||
# Fall through to the query parameter handling
|
||||
pass
|
||||
|
||||
if not is_post_put:
|
||||
# GET with query parameters - create a function with the actual query parameters
|
||||
def create_query_endpoint_func():
|
||||
# Build the function signature dynamically
|
||||
import inspect
|
||||
|
||||
# Create parameter annotations
|
||||
param_annotations = {}
|
||||
param_defaults = {}
|
||||
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
# Handle problematic type annotations that cause FastAPI issues
|
||||
safe_type = _make_type_safe_for_fastapi(param_type)
|
||||
param_annotations[param_name] = safe_type
|
||||
if default_value is not None:
|
||||
param_defaults[param_name] = default_value
|
||||
|
||||
# Create the function with the correct signature
|
||||
def create_endpoint_func():
|
||||
# Sort parameters so that required parameters come before optional ones
|
||||
# Parameters with None default are required, others are optional
|
||||
sorted_params = sorted(
|
||||
query_parameters,
|
||||
key=lambda x: (x[2] is not None, x[0]), # False (required) comes before True (optional)
|
||||
)
|
||||
for param_name, param_type, default_value in query_parameters
|
||||
]
|
||||
)
|
||||
|
||||
async def query_endpoint(**kwargs) -> response_model:
|
||||
"""Query endpoint for proper schema generation."""
|
||||
return response_model()
|
||||
# Create the function signature
|
||||
sig = inspect.Signature(
|
||||
[
|
||||
inspect.Parameter(
|
||||
name=param_name,
|
||||
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
default=default_value if default_value is not None else inspect.Parameter.empty,
|
||||
annotation=param_annotations[param_name],
|
||||
)
|
||||
for param_name, param_type, default_value in sorted_params
|
||||
]
|
||||
)
|
||||
|
||||
# Set the signature
|
||||
query_endpoint.__signature__ = sig
|
||||
query_endpoint.__annotations__ = param_annotations
|
||||
return query_endpoint
|
||||
# Create a simple function without **kwargs
|
||||
async def query_endpoint():
|
||||
"""Query endpoint for proper schema generation."""
|
||||
return response_model()
|
||||
|
||||
endpoint_func = create_query_endpoint_func()
|
||||
# Set the signature and annotations
|
||||
query_endpoint.__signature__ = sig
|
||||
query_endpoint.__annotations__ = param_annotations
|
||||
|
||||
return query_endpoint
|
||||
|
||||
return create_endpoint_func()
|
||||
|
||||
endpoint_func = create_query_endpoint_func()
|
||||
elif response_model:
|
||||
# Response-only endpoint (no parameters)
|
||||
async def response_only_endpoint() -> response_model:
|
||||
|
|
@ -289,6 +392,10 @@ def _find_models_for_endpoint(webmethod) -> tuple[type | None, type | None, list
|
|||
if param_name == "self":
|
||||
continue
|
||||
|
||||
# Skip *args and **kwargs parameters - these are not real API parameters
|
||||
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
# Check if it's a Pydantic model (for POST/PUT requests)
|
||||
param_type = param.annotation
|
||||
if hasattr(param_type, "model_json_schema"):
|
||||
|
|
@ -319,8 +426,17 @@ def _find_models_for_endpoint(webmethod) -> tuple[type | None, type | None, list
|
|||
elif get_origin(return_annotation) is Annotated:
|
||||
# Handle Annotated return types
|
||||
args = get_args(return_annotation)
|
||||
if args and hasattr(args[0], "model_json_schema"):
|
||||
response_model = args[0]
|
||||
if args:
|
||||
# Check if the first argument is a Pydantic model
|
||||
if hasattr(args[0], "model_json_schema"):
|
||||
response_model = args[0]
|
||||
# Check if the first argument is a union type
|
||||
elif get_origin(args[0]) is type(args[0]): # Union type
|
||||
union_args = get_args(args[0])
|
||||
for arg in union_args:
|
||||
if hasattr(arg, "model_json_schema"):
|
||||
response_model = arg
|
||||
break
|
||||
elif get_origin(return_annotation) is type(return_annotation): # Union type
|
||||
# Handle union types - try to find the first Pydantic model
|
||||
args = get_args(return_annotation)
|
||||
|
|
@ -340,6 +456,7 @@ def _make_type_safe_for_fastapi(type_hint) -> type:
|
|||
"""
|
||||
Make a type hint safe for FastAPI by converting problematic types to their base types.
|
||||
This handles cases like Literal["24h"] that cause forward reference errors.
|
||||
Also removes Union with None to avoid anyOf with type: 'null' schemas.
|
||||
"""
|
||||
# Handle Literal types that might cause issues
|
||||
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Literal:
|
||||
|
|
@ -369,11 +486,16 @@ def _make_type_safe_for_fastapi(type_hint) -> type:
|
|||
if origin is type(type_hint) or (hasattr(type_hint, "__args__") and type_hint.__args__):
|
||||
# This is a union type, find the non-None type
|
||||
args = get_args(type_hint)
|
||||
for arg in args:
|
||||
if arg is not type(None) and arg is not None:
|
||||
return arg
|
||||
# If all args are None, return the first one
|
||||
return args[0] if args else type_hint
|
||||
non_none_types = [arg for arg in args if arg is not type(None) and arg is not None]
|
||||
|
||||
if non_none_types:
|
||||
# Return the first non-None type to avoid anyOf with null
|
||||
return non_none_types[0]
|
||||
elif args:
|
||||
# If all args are None, return the first one
|
||||
return args[0]
|
||||
else:
|
||||
return type_hint
|
||||
|
||||
# Not a union type, return as-is
|
||||
return type_hint
|
||||
|
|
@ -475,6 +597,202 @@ def _find_extra_body_params_for_route(api_name: str, route, webmethod) -> list[d
|
|||
return []
|
||||
|
||||
|
||||
def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure all @json_schema_type decorated models are included in the OpenAPI schema.
    This finds all models with the _llama_stack_schema_type attribute and adds them to the schema.

    Mutates ``openapi_schema`` in place and also returns it for chaining.
    Existing entries in components/schemas are never overwritten — only
    missing schema names are added.
    """
    # Make sure components/schemas exists before we start inserting into it.
    if "components" not in openapi_schema:
        openapi_schema["components"] = {}

    if "schemas" not in openapi_schema["components"]:
        openapi_schema["components"]["schemas"] = {}

    # Find all classes with the _llama_stack_schema_type attribute
    from llama_stack import apis

    # Get all modules in the apis package.
    # NOTE(review): `hasattr(module, "__file__")` is used as a heuristic for
    # "this attribute is a module" — it will also admit any object that
    # happens to define __file__; confirm this is intended.
    apis_modules = []
    for module_name in dir(apis):
        if not module_name.startswith("_"):
            try:
                module = getattr(apis, module_name)
                if hasattr(module, "__file__"):
                    apis_modules.append(module)
            except (ImportError, AttributeError):
                continue

    # Also check submodules.
    # NOTE(review): this loop appends to `apis_modules` while iterating it, so
    # newly appended submodules are themselves scanned in later iterations —
    # effectively a breadth-style crawl. It terminates only as long as the
    # module graph does not keep yielding new __file__-bearing attributes;
    # verify there is no pathological cycle.
    for module in apis_modules:
        for attr_name in dir(module):
            if not attr_name.startswith("_"):
                try:
                    attr = getattr(module, attr_name)
                    if hasattr(attr, "__file__") and hasattr(attr, "__name__"):
                        apis_modules.append(attr)
                except (ImportError, AttributeError):
                    continue

    # Find all classes with the _llama_stack_schema_type attribute and add
    # their Pydantic JSON schema under their class name.
    for module in apis_modules:
        for attr_name in dir(module):
            try:
                attr = getattr(module, attr_name)
                if (
                    hasattr(attr, "_llama_stack_schema_type")
                    and hasattr(attr, "model_json_schema")
                    and hasattr(attr, "__name__")
                ):
                    schema_name = attr.__name__
                    if schema_name not in openapi_schema["components"]["schemas"]:
                        try:
                            schema = attr.model_json_schema()
                            openapi_schema["components"]["schemas"][schema_name] = schema
                        except Exception:
                            # Skip if we can't generate the schema
                            continue
            except (AttributeError, TypeError):
                continue

    # Also include any dynamic models that were created during endpoint generation
    # This is a workaround to ensure dynamic models appear in the schema
    global _dynamic_models
    # NOTE(review): the `in globals()` guard is redundant — `_dynamic_models`
    # is defined at module level — but kept as defensive belt-and-braces.
    if "_dynamic_models" in globals():
        for model in _dynamic_models:
            try:
                schema_name = model.__name__
                if schema_name not in openapi_schema["components"]["schemas"]:
                    schema = model.model_json_schema()
                    openapi_schema["components"]["schemas"][schema_name] = schema
            except Exception:
                # Skip if we can't generate the schema
                continue

    return openapi_schema
|
||||
|
||||
|
||||
def _fix_ref_references(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Fix $ref references to point to components/schemas instead of $defs.
|
||||
This prevents the YAML dumper from creating a root-level $defs section.
|
||||
"""
|
||||
|
||||
def fix_refs(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
|
||||
# Replace #/$defs/ with #/components/schemas/
|
||||
obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/")
|
||||
for value in obj.values():
|
||||
fix_refs(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
fix_refs(item)
|
||||
|
||||
fix_refs(openapi_schema)
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _fix_anyof_with_null(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Fix anyOf schemas that contain type: 'null' by removing the null type
|
||||
and making the field optional through the required field instead.
|
||||
"""
|
||||
|
||||
def fix_anyof(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "anyOf" in obj and isinstance(obj["anyOf"], list):
|
||||
# Check if anyOf contains type: 'null'
|
||||
has_null = any(item.get("type") == "null" for item in obj["anyOf"] if isinstance(item, dict))
|
||||
if has_null:
|
||||
# Remove null types and keep only the non-null types
|
||||
non_null_types = [
|
||||
item for item in obj["anyOf"] if not (isinstance(item, dict) and item.get("type") == "null")
|
||||
]
|
||||
if len(non_null_types) == 1:
|
||||
# If only one non-null type remains, replace anyOf with that type
|
||||
obj.update(non_null_types[0])
|
||||
if "anyOf" in obj:
|
||||
del obj["anyOf"]
|
||||
else:
|
||||
# Keep the anyOf but without null types
|
||||
obj["anyOf"] = non_null_types
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
fix_anyof(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
fix_anyof(item)
|
||||
|
||||
fix_anyof(openapi_schema)
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _eliminate_defs_section(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Eliminate $defs section entirely by moving all definitions to components/schemas.
|
||||
This matches the structure of the old pyopenapi generator for oasdiff compatibility.
|
||||
"""
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
|
||||
if "schemas" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["schemas"] = {}
|
||||
|
||||
# First pass: collect all $defs from anywhere in the schema
|
||||
defs_to_move = {}
|
||||
|
||||
def collect_defs(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$defs" in obj:
|
||||
# Collect $defs for later processing
|
||||
for def_name, def_schema in obj["$defs"].items():
|
||||
if def_name not in defs_to_move:
|
||||
defs_to_move[def_name] = def_schema
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
collect_defs(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
collect_defs(item)
|
||||
|
||||
# Collect all $defs
|
||||
collect_defs(openapi_schema)
|
||||
|
||||
# Move all $defs to components/schemas
|
||||
for def_name, def_schema in defs_to_move.items():
|
||||
if def_name not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"][def_name] = def_schema
|
||||
|
||||
# Also move any existing root-level $defs to components/schemas
|
||||
if "$defs" in openapi_schema:
|
||||
print(f"Found root-level $defs with {len(openapi_schema['$defs'])} items, moving to components/schemas")
|
||||
for def_name, def_schema in openapi_schema["$defs"].items():
|
||||
if def_name not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"][def_name] = def_schema
|
||||
# Remove the root-level $defs
|
||||
del openapi_schema["$defs"]
|
||||
|
||||
# Second pass: remove all $defs sections from anywhere in the schema
|
||||
def remove_defs(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$defs" in obj:
|
||||
del obj["$defs"]
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
remove_defs(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
remove_defs(item)
|
||||
|
||||
# Remove all $defs sections
|
||||
remove_defs(openapi_schema)
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Add standard error response definitions to the OpenAPI schema.
|
||||
|
|
@ -547,10 +865,40 @@ def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
return openapi_schema
|
||||
|
||||
|
||||
def _fix_path_parameters(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Fix path parameter resolution issues by adding explicit parameter definitions.

    For every templated path (e.g. /v1/things/{thing_id}) each operation gets
    an explicit path-parameter entry unless one with the same name already
    exists. The schema is modified in place and returned.
    """
    http_verbs = ("get", "post", "put", "delete", "patch", "head", "options")

    for path, path_item in openapi_schema.get("paths", {}).items():
        declared = _extract_path_parameters(path)
        if not declared:
            # Nothing templated in this path; leave its operations untouched.
            continue

        for verb in http_verbs:
            operation = path_item.get(verb) if isinstance(path_item, dict) else None
            if not isinstance(operation, dict):
                continue

            params = operation.setdefault("parameters", [])
            # Names already declared as path parameters on this operation.
            present = {p.get("name") for p in params if p.get("in") == "path"}
            params.extend(param for param in declared if param["name"] not in present)

    return openapi_schema
|
||||
|
||||
|
||||
def _fix_schema_issues(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Fix common schema issues that cause OpenAPI validation problems.
|
||||
This includes converting exclusiveMinimum numbers to minimum values.
|
||||
This includes converting exclusiveMinimum numbers to minimum values and fixing string fields with null defaults.
|
||||
"""
|
||||
if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]:
|
||||
return openapi_schema
|
||||
|
|
@ -560,10 +908,64 @@ def _fix_schema_issues(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
# Fix exclusiveMinimum issues
|
||||
for _, schema_def in schemas.items():
|
||||
_fix_exclusive_minimum_in_schema(schema_def)
|
||||
_fix_all_null_defaults(schema_def)
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI schema") -> bool:
    """
    Validate an OpenAPI schema using openapi-spec-validator.

    Prints a human-readable pass/fail line for the operator. Never raises:
    every failure path — including unexpected errors from the validator — is
    reported through the boolean return value.

    Args:
        schema: The OpenAPI schema dictionary to validate
        schema_name: Name of the schema for error reporting

    Returns:
        True if valid, False otherwise
    """
    try:
        validate_spec(schema)
        print(f"✅ {schema_name} is valid")
        return True
    except OpenAPISpecValidatorError as e:
        # Expected validator failure: print the reason and report invalid.
        print(f"❌ {schema_name} validation failed:")
        print(f"  {e}")
        return False
    except Exception as e:
        # Deliberate catch-all so spec generation keeps going even when the
        # validator errors out in an unexpected way; the error is surfaced to
        # the operator instead of aborting the run.
        print(f"❌ {schema_name} validation error: {e}")
        return False
|
||||
|
||||
|
||||
def validate_schema_file(file_path: Path) -> bool:
    """
    Validate an OpenAPI schema file (YAML or JSON).

    Loads the file based on its extension (.yaml/.yml via yaml.safe_load,
    .json via json.load) and delegates to validate_openapi_schema. Never
    raises: read/parse problems are printed and reported as False.

    Args:
        file_path: Path to the schema file

    Returns:
        True if valid, False otherwise
    """
    try:
        # YAML and JSON are UTF-8 by specification; be explicit rather than
        # relying on the platform's default locale encoding (fixes mojibake /
        # decode errors on non-UTF-8 systems, e.g. Windows cp1252).
        with open(file_path, encoding="utf-8") as f:
            if file_path.suffix.lower() in [".yaml", ".yml"]:
                schema = yaml.safe_load(f)
            elif file_path.suffix.lower() == ".json":
                schema = json.load(f)
            else:
                print(f"❌ Unsupported file format: {file_path.suffix}")
                return False

        return validate_openapi_schema(schema, str(file_path))
    except Exception as e:
        # Best-effort tool behavior: report the failure instead of crashing.
        print(f"❌ Failed to read {file_path}: {e}")
        return False
|
||||
|
||||
|
||||
def _fix_exclusive_minimum_in_schema(obj: Any) -> None:
|
||||
"""
|
||||
Recursively fix exclusiveMinimum issues in a schema object.
|
||||
|
|
@ -586,6 +988,75 @@ def _fix_exclusive_minimum_in_schema(obj: Any) -> None:
|
|||
_fix_exclusive_minimum_in_schema(item)
|
||||
|
||||
|
||||
def _fix_string_fields_with_null_defaults(obj: Any) -> None:
|
||||
"""
|
||||
Recursively fix string fields that have default: null.
|
||||
This violates OpenAPI spec - string fields should either have a string default or be optional.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
# Check if this is a field definition with type: string and default: null
|
||||
if obj.get("type") == "string" and "default" in obj and obj["default"] is None:
|
||||
# Remove the default: null to make the field optional
|
||||
del obj["default"]
|
||||
# Add nullable: true to indicate the field can be null
|
||||
obj["nullable"] = True
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
_fix_string_fields_with_null_defaults(value)
|
||||
|
||||
elif isinstance(obj, list):
|
||||
# Recursively process all items
|
||||
for item in obj:
|
||||
_fix_string_fields_with_null_defaults(item)
|
||||
|
||||
|
||||
def _fix_anyof_with_null_defaults(obj: Any) -> None:
|
||||
"""
|
||||
Recursively fix anyOf schemas that have default: null.
|
||||
This violates OpenAPI spec - anyOf fields should not have null defaults.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
# Check if this is a field definition with anyOf and default: null
|
||||
if "anyOf" in obj and "default" in obj and obj["default"] is None:
|
||||
# Remove the default: null to make the field optional
|
||||
del obj["default"]
|
||||
# Add nullable: true to indicate the field can be null
|
||||
obj["nullable"] = True
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
_fix_anyof_with_null_defaults(value)
|
||||
|
||||
elif isinstance(obj, list):
|
||||
# Recursively process all items
|
||||
for item in obj:
|
||||
_fix_anyof_with_null_defaults(item)
|
||||
|
||||
|
||||
def _fix_all_null_defaults(obj: Any) -> None:
|
||||
"""
|
||||
Recursively fix all field types that have default: null.
|
||||
This violates OpenAPI spec - fields should not have null defaults.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
# Check if this is a field definition with default: null
|
||||
if "default" in obj and obj["default"] is None:
|
||||
# Remove the default: null to make the field optional
|
||||
del obj["default"]
|
||||
# Add nullable: true to indicate the field can be null
|
||||
obj["nullable"] = True
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
_fix_all_null_defaults(value)
|
||||
|
||||
elif isinstance(obj, list):
|
||||
# Recursively process all items
|
||||
for item in obj:
|
||||
_fix_all_null_defaults(item)
|
||||
|
||||
|
||||
def _sort_paths_alphabetically(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Sort the paths in the OpenAPI schema by version prefix first, then alphabetically.
|
||||
|
|
@ -703,6 +1174,15 @@ def _filter_schema_by_version(
|
|||
|
||||
filtered_schema["components"]["schemas"] = filtered_schemas
|
||||
|
||||
# Preserve $defs section if it exists
|
||||
if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
|
||||
if "components" not in filtered_schema:
|
||||
filtered_schema["components"] = {}
|
||||
filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
|
||||
print(f"Preserved $defs section with {len(openapi_schema['components']['$defs'])} items")
|
||||
else:
|
||||
print("No $defs section to preserve")
|
||||
|
||||
return filtered_schema
|
||||
|
||||
|
||||
|
|
@ -811,6 +1291,49 @@ def _filter_deprecated_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
|||
return filtered_schema
|
||||
|
||||
|
||||
def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Filter OpenAPI schema to include both stable (v1) and experimental (v1alpha, v1beta) APIs.
    Excludes deprecated endpoints. This is used for the combined "stainless" spec.

    Returns a filtered copy; the input schema is not mutated.
    """
    filtered_schema = openapi_schema.copy()

    if "paths" not in filtered_schema:
        return filtered_schema

    # Filter paths to include stable (v1) and experimental (v1alpha, v1beta), excluding deprecated
    filtered_paths = {}
    for path, path_item in filtered_schema["paths"].items():
        # Skip paths with any deprecated operations
        if _is_path_deprecated(path_item):
            continue

        # Include /v1/ paths (stable)
        if path.startswith("/v1/") and not path.startswith("/v1alpha/") and not path.startswith("/v1beta/"):
            filtered_paths[path] = path_item
        # Include /v1alpha/ and /v1beta/ paths (experimental)
        elif path.startswith("/v1alpha/") or path.startswith("/v1beta/"):
            filtered_paths[path] = path_item

    filtered_schema["paths"] = filtered_paths

    # Filter schemas/components to only include ones referenced by filtered paths
    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)

        kept_schemas = {
            schema_name: schema_def
            for schema_name, schema_def in filtered_schema["components"]["schemas"].items()
            if schema_name in referenced_schemas
        }

        # BUGFIX: filtered_schema is a *shallow* copy, so
        # filtered_schema["components"] is the same dict object as
        # openapi_schema["components"]. Assigning into it would silently
        # mutate the caller's schema. Rebind a fresh components dict instead.
        filtered_schema["components"] = {**filtered_schema["components"], "schemas": kept_schemas}

    return filtered_schema
|
||||
|
||||
|
||||
def generate_openapi_spec(output_dir: str, format: str = "yaml", include_examples: bool = True) -> dict[str, Any]:
|
||||
"""
|
||||
Generate OpenAPI specification using FastAPI's built-in method.
|
||||
|
|
@ -835,12 +1358,63 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
servers=app.servers,
|
||||
)
|
||||
|
||||
# Debug: Check if there's a root-level $defs in the original schema
|
||||
if "$defs" in openapi_schema:
|
||||
print(f"Original schema has root-level $defs with {len(openapi_schema['$defs'])} items")
|
||||
else:
|
||||
print("Original schema has no root-level $defs")
|
||||
|
||||
# Add Llama Stack specific extensions
|
||||
openapi_schema = _add_llama_stack_extensions(openapi_schema, app)
|
||||
|
||||
# Add standard error responses
|
||||
openapi_schema = _add_error_responses(openapi_schema)
|
||||
|
||||
# Ensure all @json_schema_type decorated models are included
|
||||
openapi_schema = _ensure_json_schema_types_included(openapi_schema)
|
||||
|
||||
# Fix $ref references to point to components/schemas instead of $defs
|
||||
openapi_schema = _fix_ref_references(openapi_schema)
|
||||
|
||||
# Debug: Check if there are any $ref references to $defs in the schema
|
||||
defs_refs = []
|
||||
|
||||
def find_defs_refs(obj: Any, path: str = "") -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
|
||||
defs_refs.append(f"{path}: {obj['$ref']}")
|
||||
for key, value in obj.items():
|
||||
find_defs_refs(value, f"{path}.{key}" if path else key)
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
find_defs_refs(item, f"{path}[{i}]")
|
||||
|
||||
find_defs_refs(openapi_schema)
|
||||
if defs_refs:
|
||||
print(f"Found {len(defs_refs)} $ref references to $defs in schema")
|
||||
for ref in defs_refs[:5]: # Show first 5
|
||||
print(f" {ref}")
|
||||
else:
|
||||
print("No $ref references to $defs found in schema")
|
||||
|
||||
# Note: Let Pydantic/FastAPI generate the correct, standards-compliant schema
|
||||
# Fields with default values should be optional according to OpenAPI standards
|
||||
|
||||
# Fix anyOf schemas with type: 'null' to avoid oasdiff errors
|
||||
openapi_schema = _fix_anyof_with_null(openapi_schema)
|
||||
|
||||
# Fix path parameter resolution issues
|
||||
openapi_schema = _fix_path_parameters(openapi_schema)
|
||||
|
||||
# Eliminate $defs section entirely for oasdiff compatibility
|
||||
openapi_schema = _eliminate_defs_section(openapi_schema)
|
||||
|
||||
# Debug: Check if there's a root-level $defs after flattening
|
||||
if "$defs" in openapi_schema:
|
||||
print(f"After flattening: root-level $defs with {len(openapi_schema['$defs'])} items")
|
||||
else:
|
||||
print("After flattening: no root-level $defs")
|
||||
|
||||
# Ensure all referenced schemas are included
|
||||
# DISABLED: This was using hardcoded schema generation. FastAPI should handle this automatically.
|
||||
# openapi_schema = _ensure_referenced_schemas(openapi_schema)
|
||||
|
|
@ -853,7 +1427,7 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
# DISABLED: This was a hardcoded workaround. Using Pydantic's TypeAdapter instead.
|
||||
# _fix_malformed_schemas(openapi_schema)
|
||||
|
||||
# Split into stable (v1 only), experimental (v1alpha + v1beta), and deprecated specs
|
||||
# Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
|
||||
# Each spec needs its own deep copy of the full schema to avoid cross-contamination
|
||||
import copy
|
||||
|
||||
|
|
@ -862,6 +1436,16 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
copy.deepcopy(openapi_schema), stable_only=False, exclude_deprecated=True
|
||||
)
|
||||
deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema))
|
||||
combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema))
|
||||
|
||||
# Update title and description for combined schema
|
||||
if "info" in combined_schema:
|
||||
combined_schema["info"]["title"] = "Llama Stack API - Stable & Experimental APIs"
|
||||
combined_schema["info"]["description"] = (
|
||||
combined_schema["info"].get("description", "")
|
||||
+ "\n\n**🔗 COMBINED**: This specification includes both stable production-ready APIs and experimental pre-release APIs. "
|
||||
"Use stable APIs for production deployments and experimental APIs for testing new features."
|
||||
)
|
||||
|
||||
# Sort paths alphabetically for stable (v1 only)
|
||||
stable_schema = _sort_paths_alphabetically(stable_schema)
|
||||
|
|
@ -869,11 +1453,24 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
experimental_schema = _sort_paths_alphabetically(experimental_schema)
|
||||
# Sort paths by version prefix for deprecated
|
||||
deprecated_schema = _sort_paths_alphabetically(deprecated_schema)
|
||||
# Sort paths by version prefix for combined (stainless)
|
||||
combined_schema = _sort_paths_alphabetically(combined_schema)
|
||||
|
||||
# Fix schema issues (like exclusiveMinimum -> minimum) for each spec
|
||||
stable_schema = _fix_schema_issues(stable_schema)
|
||||
experimental_schema = _fix_schema_issues(experimental_schema)
|
||||
deprecated_schema = _fix_schema_issues(deprecated_schema)
|
||||
combined_schema = _fix_schema_issues(combined_schema)
|
||||
|
||||
# Validate the schemas
|
||||
print("\n🔍 Validating generated schemas...")
|
||||
stable_valid = validate_openapi_schema(stable_schema, "Stable schema")
|
||||
experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema")
|
||||
deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema")
|
||||
combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema")
|
||||
|
||||
if not all([stable_valid, experimental_valid, deprecated_valid, combined_valid]):
|
||||
print("⚠️ Some schemas failed validation, but continuing with generation...")
|
||||
|
||||
# Add any custom modifications here if needed
|
||||
if include_examples:
|
||||
|
|
@ -887,8 +1484,60 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
# Save the stable specification
|
||||
if format in ["yaml", "both"]:
|
||||
yaml_path = output_path / "llama-stack-spec.yaml"
|
||||
with open(yaml_path, "w") as f:
|
||||
yaml.dump(stable_schema, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
# Use ruamel.yaml for better control over YAML serialization
|
||||
try:
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
yaml_writer = YAML()
|
||||
yaml_writer.default_flow_style = False
|
||||
yaml_writer.sort_keys = False
|
||||
yaml_writer.width = 4096 # Prevent line wrapping
|
||||
yaml_writer.allow_unicode = True
|
||||
|
||||
with open(yaml_path, "w") as f:
|
||||
yaml_writer.dump(stable_schema, f)
|
||||
except ImportError:
|
||||
# Fallback to standard yaml if ruamel.yaml is not available
|
||||
with open(yaml_path, "w") as f:
|
||||
yaml.dump(stable_schema, f, default_flow_style=False, sort_keys=False)
|
||||
# Post-process the YAML file to remove $defs section and fix references
|
||||
with open(yaml_path) as f:
|
||||
yaml_content = f.read()
|
||||
|
||||
if " $defs:" in yaml_content or "#/$defs/" in yaml_content:
|
||||
print("Post-processing YAML to remove $defs section")
|
||||
|
||||
# Use string replacement to fix references directly
|
||||
if "#/$defs/" in yaml_content:
|
||||
refs_fixed = yaml_content.count("#/$defs/")
|
||||
yaml_content = yaml_content.replace("#/$defs/", "#/components/schemas/")
|
||||
print(f"Fixed {refs_fixed} $ref references using string replacement")
|
||||
|
||||
# Parse the YAML content
|
||||
yaml_data = yaml.safe_load(yaml_content)
|
||||
|
||||
# Move $defs to components/schemas if it exists
|
||||
if "$defs" in yaml_data:
|
||||
print(f"Found $defs section with {len(yaml_data['$defs'])} items")
|
||||
if "components" not in yaml_data:
|
||||
yaml_data["components"] = {}
|
||||
if "schemas" not in yaml_data["components"]:
|
||||
yaml_data["components"]["schemas"] = {}
|
||||
|
||||
# Move all $defs to components/schemas
|
||||
for def_name, def_schema in yaml_data["$defs"].items():
|
||||
yaml_data["components"]["schemas"][def_name] = def_schema
|
||||
|
||||
# Remove the $defs section
|
||||
del yaml_data["$defs"]
|
||||
print("Moved $defs to components/schemas")
|
||||
|
||||
# Write the modified YAML back
|
||||
with open(yaml_path, "w") as f:
|
||||
yaml.dump(yaml_data, f, default_flow_style=False, sort_keys=False)
|
||||
print("Updated YAML file")
|
||||
|
||||
print(f"✅ Generated YAML (stable): {yaml_path}")
|
||||
|
||||
experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
|
||||
|
|
@ -901,6 +1550,25 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
yaml.dump(deprecated_schema, f, default_flow_style=False, sort_keys=False)
|
||||
print(f"✅ Generated YAML (deprecated): {deprecated_yaml_path}")
|
||||
|
||||
# Generate combined (stainless) spec
|
||||
stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
|
||||
try:
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
yaml_writer = YAML()
|
||||
yaml_writer.default_flow_style = False
|
||||
yaml_writer.sort_keys = False
|
||||
yaml_writer.width = 4096 # Prevent line wrapping
|
||||
yaml_writer.allow_unicode = True
|
||||
|
||||
with open(stainless_yaml_path, "w") as f:
|
||||
yaml_writer.dump(combined_schema, f)
|
||||
except ImportError:
|
||||
# Fallback to standard yaml if ruamel.yaml is not available
|
||||
with open(stainless_yaml_path, "w") as f:
|
||||
yaml.dump(combined_schema, f, default_flow_style=False, sort_keys=False)
|
||||
print(f"✅ Generated YAML (stainless/combined): {stainless_yaml_path}")
|
||||
|
||||
if format in ["json", "both"]:
|
||||
json_path = output_path / "llama-stack-spec.json"
|
||||
with open(json_path, "w") as f:
|
||||
|
|
@ -917,6 +1585,11 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
json.dump(deprecated_schema, f, indent=2)
|
||||
print(f"✅ Generated JSON (deprecated): {deprecated_json_path}")
|
||||
|
||||
stainless_json_path = output_path / "stainless-llama-stack-spec.json"
|
||||
with open(stainless_json_path, "w") as f:
|
||||
json.dump(combined_schema, f, indent=2)
|
||||
print(f"✅ Generated JSON (stainless/combined): {stainless_json_path}")
|
||||
|
||||
# Generate HTML documentation
|
||||
html_path = output_path / "llama-stack-spec.html"
|
||||
generate_html_docs(stable_schema, html_path)
|
||||
|
|
@ -930,6 +1603,10 @@ def generate_openapi_spec(output_dir: str, format: str = "yaml", include_example
|
|||
generate_html_docs(deprecated_schema, deprecated_html_path, spec_file="deprecated-llama-stack-spec.yaml")
|
||||
print(f"✅ Generated HTML (deprecated): {deprecated_html_path}")
|
||||
|
||||
stainless_html_path = output_path / "stainless-llama-stack-spec.html"
|
||||
generate_html_docs(combined_schema, stainless_html_path, spec_file="stainless-llama-stack-spec.yaml")
|
||||
print(f"✅ Generated HTML (stainless/combined): {stainless_html_path}")
|
||||
|
||||
return stable_schema
|
||||
|
||||
|
||||
|
|
@ -968,9 +1645,55 @@ def main():
|
|||
parser.add_argument("output_dir", help="Output directory for generated files")
|
||||
parser.add_argument("--format", choices=["yaml", "json", "both"], default="yaml", help="Output format")
|
||||
parser.add_argument("--no-examples", action="store_true", help="Exclude examples from the specification")
|
||||
parser.add_argument(
|
||||
"--validate-only", action="store_true", help="Only validate existing schema files, don't generate new ones"
|
||||
)
|
||||
parser.add_argument("--validate-file", help="Validate a specific schema file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle validation-only mode
|
||||
if args.validate_only or args.validate_file:
|
||||
if args.validate_file:
|
||||
# Validate a specific file
|
||||
file_path = Path(args.validate_file)
|
||||
if not file_path.exists():
|
||||
print(f"❌ File not found: {file_path}")
|
||||
return 1
|
||||
|
||||
print(f"🔍 Validating {file_path}...")
|
||||
is_valid = validate_schema_file(file_path)
|
||||
return 0 if is_valid else 1
|
||||
else:
|
||||
# Validate all schema files in output directory
|
||||
output_path = Path(args.output_dir)
|
||||
if not output_path.exists():
|
||||
print(f"❌ Output directory not found: {output_path}")
|
||||
return 1
|
||||
|
||||
print(f"🔍 Validating all schema files in {output_path}...")
|
||||
schema_files = (
|
||||
list(output_path.glob("*.yaml")) + list(output_path.glob("*.yml")) + list(output_path.glob("*.json"))
|
||||
)
|
||||
|
||||
if not schema_files:
|
||||
print("❌ No schema files found to validate")
|
||||
return 1
|
||||
|
||||
all_valid = True
|
||||
for schema_file in schema_files:
|
||||
print(f"\n📄 Validating {schema_file.name}...")
|
||||
is_valid = validate_schema_file(schema_file)
|
||||
if not is_valid:
|
||||
all_valid = False
|
||||
|
||||
if all_valid:
|
||||
print("\n✅ All schema files are valid!")
|
||||
return 0
|
||||
else:
|
||||
print("\n❌ Some schema files failed validation")
|
||||
return 1
|
||||
|
||||
print("🚀 Generating OpenAPI specification using FastAPI...")
|
||||
print(f"📁 Output directory: {args.output_dir}")
|
||||
print(f"📄 Format: {args.format}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue