mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
131 lines
5.3 KiB
Python
131 lines
5.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
"""
|
|
Schema discovery and collection for OpenAPI generation.
|
|
"""
|
|
|
|
import importlib
|
|
from typing import Any
|
|
|
|
|
|
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
    """
    Make sure ``openapi_schema["components"]["schemas"]`` is present.

    Missing levels are created as empty dicts; anything already there is left
    untouched. The document is modified in place.
    """
    components = openapi_schema.setdefault("components", {})
    if "schemas" not in components:
        components["schemas"] = {}
|
|
|
|
|
|
def _load_extra_schema_modules() -> None:
    """
    Import server-side modules that register schemas through schema_utils.

    Modules inside llama_stack_api are already imported by that package's __init__,
    but modules living elsewhere (such as telemetry) have to be imported here so that
    their registration decorators run. A module that raises ImportError is skipped.
    """
    # Imported only for their import-time registration side effects.
    side_effect_modules = ("llama_stack.core.telemetry.telemetry",)
    for dotted_path in side_effect_modules:
        try:
            importlib.import_module(dotted_path)
        except ImportError:
            # Not importable in this environment; skip it and move on.
            pass
|
|
|
|
|
|
def _extract_and_fix_defs(schema: dict[str, Any], openapi_schema: dict[str, Any]) -> None:
    """
    Extract $defs from a schema, move them to components/schemas, and fix references.

    This handles both TypeAdapter-generated schemas and model_json_schema() schemas.

    Args:
        schema: JSON schema dict, modified in place. Its top-level ``$defs`` entry is
            removed, and every ``#/$defs/<name>`` reference is rewritten to
            ``#/components/schemas/<name>``.
        openapi_schema: OpenAPI document whose ``components.schemas`` must already
            exist when ``schema`` has ``$defs``. A definition is added only if no
            component of the same name is present; existing components are never
            overwritten.
    """
    defs = schema.pop("$defs", None)
    if defs:
        components = openapi_schema["components"]["schemas"]
        for def_name, def_schema in defs.items():
            if def_name not in components:
                components[def_name] = def_schema
                # A definition can carry its own $defs and #/$defs/ refs; normalize it too.
                _extract_and_fix_defs(def_schema, openapi_schema)

    # Fix any references in the main schema that point to $defs
    _rewrite_defs_refs(schema)


def _rewrite_defs_refs(obj: Any) -> None:
    """Recursively rewrite ``#/$defs/<name>`` refs inside ``obj`` to ``#/components/schemas/<name>``."""
    if isinstance(obj, dict):
        ref = obj.get("$ref")
        # Only string values are references; a property *named* "$ref" maps to a sub-schema dict.
        if isinstance(ref, str) and ref.startswith("#/$defs/"):
            obj["$ref"] = "#/components/schemas/" + ref.removeprefix("#/$defs/")
        for value in obj.values():
            _rewrite_defs_refs(value)
    elif isinstance(obj, list):
        for item in obj:
            _rewrite_defs_refs(item)
|
|
|
|
|
|
def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure all registered schemas (decorated, explicit, and dynamic) are included in the OpenAPI schema.

    Relies on llama_stack_api's registry instead of recursively importing every module.

    Three passes run in order. A name already present in ``components.schemas`` is
    never overwritten, so pre-existing components and earlier passes take precedence:

    1. explicitly registered types (union types, Annotated structs, ...), via ``TypeAdapter``;
    2. ``@json_schema_type``-decorated models;
    3. dynamic models generated while building endpoints.

    Nested ``$defs`` of each generated schema are hoisted into ``components.schemas``.
    A type whose schema cannot be generated is reported with a printed warning and
    skipped; it does not abort the pass.

    Args:
        openapi_schema: The OpenAPI document, modified in place.

    Returns:
        The same ``openapi_schema`` object.
    """
    _ensure_components_schemas(openapi_schema)

    import traceback

    from pydantic import TypeAdapter

    from llama_stack_api.schema_utils import (
        iter_dynamic_schema_types,
        iter_json_schema_types,
        iter_registered_schema_types,
    )

    components = openapi_schema["components"]["schemas"]
    ref_template = "#/components/schemas/{model}"

    def _store(schema_name: str, schema: dict[str, Any]) -> None:
        """Hoist ``schema``'s ``$defs`` into components, then register it under ``schema_name``."""
        _extract_and_fix_defs(schema, openapi_schema)
        # pydantic can render a self-referencing (recursive) model as a bare
        # {"$ref": ...} to its own definition, which was just hoisted into
        # components under this same name. Writing the bare ref over it would make
        # the component point at itself, so keep the hoisted definition instead.
        if schema == {"$ref": f"#/components/schemas/{schema_name}"} and schema_name in components:
            return
        components[schema_name] = schema

    # Import extra modules (e.g., telemetry) whose schema registrations live outside llama_stack_api
    _load_extra_schema_modules()

    # Pass 1: explicitly registered schemas (union types, Annotated structs, etc.)
    for registration_info in iter_registered_schema_types():
        schema_type = registration_info.type
        schema_name = registration_info.name
        if schema_name in components:
            continue
        try:
            schema = TypeAdapter(schema_type).json_schema(ref_template=ref_template)
            _store(schema_name, schema)
        except Exception as e:
            print(f"Warning: Failed to generate schema for registered type {schema_name}: {e}")
            traceback.print_exc()

    # Pass 2: @json_schema_type decorated models
    for model in iter_json_schema_types():
        schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None)
        if not schema_name or schema_name in components:
            continue
        try:
            if hasattr(model, "model_json_schema"):
                schema = model.model_json_schema(ref_template=ref_template)
            else:
                schema = TypeAdapter(model).json_schema(ref_template=ref_template)
            _store(schema_name, schema)
        except Exception as e:
            print(f"Warning: Failed to generate schema for {schema_name}: {e}")

    # Pass 3: dynamic models generated while building endpoints
    for model in iter_dynamic_schema_types():
        try:
            schema_name = model.__name__
            if schema_name in components:
                continue
            schema = model.model_json_schema(ref_template=ref_template)
            _store(schema_name, schema)
        except Exception as e:
            # Report failures like the other passes rather than dropping the model silently.
            print(f"Warning: Failed to generate schema for dynamic model {model!r}: {e}")

    return openapi_schema
|