# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. """ Schema discovery and collection for OpenAPI generation. """ from typing import Any def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None: """Ensure components.schemas exists in the schema.""" if "components" not in openapi_schema: openapi_schema["components"] = {} if "schemas" not in openapi_schema["components"]: openapi_schema["components"]["schemas"] = {} def _extract_and_fix_defs(schema: dict[str, Any], openapi_schema: dict[str, Any]) -> None: """ Extract $defs from a schema, move them to components/schemas, and fix references. This handles both TypeAdapter-generated schemas and model_json_schema() schemas. """ if "$defs" in schema: defs = schema.pop("$defs") for def_name, def_schema in defs.items(): if def_name not in openapi_schema["components"]["schemas"]: openapi_schema["components"]["schemas"][def_name] = def_schema # Recursively handle $defs in nested schemas _extract_and_fix_defs(def_schema, openapi_schema) # Fix any references in the main schema that point to $defs def fix_refs_in_schema(obj: Any) -> None: if isinstance(obj, dict): if "$ref" in obj and obj["$ref"].startswith("#/$defs/"): obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/") for value in obj.values(): fix_refs_in_schema(value) elif isinstance(obj, list): for item in obj: fix_refs_in_schema(item) fix_refs_in_schema(schema) def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]: """ Ensure all registered schemas (decorated, explicit, and dynamic) are included in the OpenAPI schema. Relies on llama_stack_api's registry instead of recursively importing every module. """ _ensure_components_schemas(openapi_schema) from pydantic import TypeAdapter from llama_stack_api.schema_utils import ( iter_dynamic_schema_types, iter_json_schema_types, iter_registered_schema_types, ) # Handle explicitly registered schemas first (union types, Annotated structs, etc.) for registration_info in iter_registered_schema_types(): schema_type = registration_info.type schema_name = registration_info.name if schema_name not in openapi_schema["components"]["schemas"]: try: adapter = TypeAdapter(schema_type) schema = adapter.json_schema(ref_template="#/components/schemas/{model}") _extract_and_fix_defs(schema, openapi_schema) openapi_schema["components"]["schemas"][schema_name] = schema except Exception as e: print(f"Warning: Failed to generate schema for registered type {schema_name}: {e}") import traceback traceback.print_exc() continue # Add @json_schema_type decorated models for model in iter_json_schema_types(): schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None) if not schema_name: continue if schema_name not in openapi_schema["components"]["schemas"]: try: if hasattr(model, "model_json_schema"): schema = model.model_json_schema(ref_template="#/components/schemas/{model}") else: adapter = TypeAdapter(model) schema = adapter.json_schema(ref_template="#/components/schemas/{model}") _extract_and_fix_defs(schema, openapi_schema) openapi_schema["components"]["schemas"][schema_name] = schema except Exception as e: print(f"Warning: Failed to generate schema for {schema_name}: {e}") continue # Include any dynamic models generated while building endpoints for model in iter_dynamic_schema_types(): try: schema_name = model.__name__ if schema_name not in openapi_schema["components"]["schemas"]: schema = model.model_json_schema(ref_template="#/components/schemas/{model}") _extract_and_fix_defs(schema, openapi_schema) openapi_schema["components"]["schemas"][schema_name] = schema except Exception: continue return openapi_schema