mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
even more cleanup, the deltas should be much smaller now
This commit is contained in:
parent
5293b4e5e9
commit
9deb0beb86
14 changed files with 5038 additions and 17435 deletions
|
|
@ -9,11 +9,8 @@ Schema discovery and collection for OpenAPI generation.
|
|||
"""
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
from typing import Any
|
||||
|
||||
from .state import _dynamic_models
|
||||
|
||||
|
||||
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
|
||||
"""Ensure components.schemas exists in the schema."""
|
||||
|
|
@ -23,54 +20,21 @@ def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
|
|||
openapi_schema["components"]["schemas"] = {}
|
||||
|
||||
|
||||
def _import_all_modules_in_package(package_name: str) -> list[Any]:
|
||||
def _load_extra_schema_modules() -> None:
|
||||
"""
|
||||
Dynamically import all modules in a package to trigger register_schema calls.
|
||||
Import modules outside llama_stack_api that use schema_utils to register schemas.
|
||||
|
||||
This walks through all modules in the package and imports them, ensuring
|
||||
that any register_schema() calls at module level are executed.
|
||||
|
||||
Args:
|
||||
package_name: The fully qualified package name (e.g., 'llama_stack_api')
|
||||
|
||||
Returns:
|
||||
List of imported module objects
|
||||
The API package already imports its submodules via __init__, but server-side modules
|
||||
like telemetry need to be imported explicitly so their decorator side effects run.
|
||||
"""
|
||||
modules = []
|
||||
try:
|
||||
package = importlib.import_module(package_name)
|
||||
except ImportError:
|
||||
return modules
|
||||
|
||||
package_path = getattr(package, "__path__", None)
|
||||
if not package_path:
|
||||
return modules
|
||||
|
||||
# Walk packages and modules recursively
|
||||
for _, modname, ispkg in pkgutil.walk_packages(package_path, prefix=f"{package_name}."):
|
||||
if not modname.startswith("_"):
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
modules.append(module)
|
||||
|
||||
# If this is a package, also try to import any .py files directly
|
||||
# (e.g., llama_stack_api.scoring_functions.scoring_functions)
|
||||
if ispkg:
|
||||
try:
|
||||
# Try importing the module file with the same name as the package
|
||||
# This handles cases like scoring_functions/scoring_functions.py
|
||||
module_file_name = f"{modname}.{modname.split('.')[-1]}"
|
||||
module_file = importlib.import_module(module_file_name)
|
||||
if module_file not in modules:
|
||||
modules.append(module_file)
|
||||
except (ImportError, AttributeError, TypeError):
|
||||
# It's okay if this fails - not all packages have a module file with the same name
|
||||
pass
|
||||
except (ImportError, AttributeError, TypeError):
|
||||
# Skip modules that can't be imported (e.g., missing dependencies)
|
||||
continue
|
||||
|
||||
return modules
|
||||
extra_modules = [
|
||||
"llama_stack.core.telemetry.telemetry",
|
||||
]
|
||||
for module_name in extra_modules:
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
|
||||
def _extract_and_fix_defs(schema: dict[str, Any], openapi_schema: dict[str, Any]) -> None:
|
||||
|
|
@ -102,82 +66,66 @@ def _extract_and_fix_defs(schema: dict[str, Any], openapi_schema: dict[str, Any]
|
|||
|
||||
def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Ensure all @json_schema_type decorated models and registered schemas are included in the OpenAPI schema.
|
||||
This finds all models with the _llama_stack_schema_type attribute and schemas registered via register_schema.
|
||||
Ensure all registered schemas (decorated, explicit, and dynamic) are included in the OpenAPI schema.
|
||||
Relies on llama_stack_api's registry instead of recursively importing every module.
|
||||
"""
|
||||
_ensure_components_schemas(openapi_schema)
|
||||
|
||||
# Import TypeAdapter for handling union types and other non-model types
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
# Dynamically import all modules in packages that might register schemas
|
||||
# This ensures register_schema() calls execute and populate _registered_schemas
|
||||
# Also collect the modules for later scanning of @json_schema_type decorated classes
|
||||
apis_modules = _import_all_modules_in_package("llama_stack_api")
|
||||
_import_all_modules_in_package("llama_stack.core.telemetry")
|
||||
from llama_stack_api.schema_utils import (
|
||||
iter_dynamic_schema_types,
|
||||
iter_json_schema_types,
|
||||
iter_registered_schema_types,
|
||||
)
|
||||
|
||||
# First, handle registered schemas (union types, etc.)
|
||||
from llama_stack_api.schema_utils import _registered_schemas
|
||||
# Import extra modules (e.g., telemetry) whose schema registrations live outside llama_stack_api
|
||||
_load_extra_schema_modules()
|
||||
|
||||
for schema_type, registration_info in _registered_schemas.items():
|
||||
schema_name = registration_info["name"]
|
||||
# Handle explicitly registered schemas first (union types, Annotated structs, etc.)
|
||||
for registration_info in iter_registered_schema_types():
|
||||
schema_type = registration_info.type
|
||||
schema_name = registration_info.name
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
try:
|
||||
# Use TypeAdapter for union types and other non-model types
|
||||
# Use ref_template to generate references in the format we need
|
||||
adapter = TypeAdapter(schema_type)
|
||||
schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
|
||||
|
||||
# Extract and fix $defs if present
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception as e:
|
||||
# Skip if we can't generate the schema
|
||||
print(f"Warning: Failed to generate schema for registered type {schema_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Find all classes with the _llama_stack_schema_type attribute
|
||||
# Use the modules we already imported above
|
||||
for module in apis_modules:
|
||||
for attr_name in dir(module):
|
||||
# Add @json_schema_type decorated models
|
||||
for model in iter_json_schema_types():
|
||||
schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None)
|
||||
if not schema_name:
|
||||
continue
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
try:
|
||||
attr = getattr(module, attr_name)
|
||||
if (
|
||||
hasattr(attr, "_llama_stack_schema_type")
|
||||
and hasattr(attr, "model_json_schema")
|
||||
and hasattr(attr, "__name__")
|
||||
):
|
||||
schema_name = attr.__name__
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
try:
|
||||
# Use ref_template to ensure consistent reference format and $defs handling
|
||||
schema = attr.model_json_schema(ref_template="#/components/schemas/{model}")
|
||||
# Extract and fix $defs if present (model_json_schema can also generate $defs)
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception as e:
|
||||
# Skip if we can't generate the schema
|
||||
print(f"Warning: Failed to generate schema for {schema_name}: {e}")
|
||||
continue
|
||||
except (AttributeError, TypeError):
|
||||
if hasattr(model, "model_json_schema"):
|
||||
schema = model.model_json_schema(ref_template="#/components/schemas/{model}")
|
||||
else:
|
||||
adapter = TypeAdapter(model)
|
||||
schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to generate schema for {schema_name}: {e}")
|
||||
continue
|
||||
|
||||
# Also include any dynamic models that were created during endpoint generation
|
||||
# This is a workaround to ensure dynamic models appear in the schema
|
||||
for model in _dynamic_models:
|
||||
# Include any dynamic models generated while building endpoints
|
||||
for model in iter_dynamic_schema_types():
|
||||
try:
|
||||
schema_name = model.__name__
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
schema = model.model_json_schema(ref_template="#/components/schemas/{model}")
|
||||
# Extract and fix $defs if present
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception:
|
||||
# Skip if we can't generate the schema
|
||||
continue
|
||||
|
||||
return openapi_schema
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue