chore: refactor code to reduce generator script length

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-11-12 12:12:23 +01:00
parent de4ed29310
commit 3d33291f23
3 changed files with 109 additions and 174 deletions

View file

@@ -7,7 +7,6 @@ info:
     tailored to
     best leverage Llama Models.
 
-
     **🔗 COMBINED**: This specification includes both stable production-ready APIs
     and experimental pre-release APIs. Use stable APIs for production deployments
     and experimental APIs for testing new features.

View file

@@ -7,7 +7,6 @@ info:
     tailored to
     best leverage Llama Models.
 
-
    **🔗 COMBINED**: This specification includes both stable production-ready APIs
    and experimental pre-release APIs. Use stable APIs for production deployments
    and experimental APIs for testing new features.

View file

@@ -1350,6 +1350,69 @@ def _write_yaml_file(file_path: Path, schema: dict[str, Any]) -> None:
         yaml.dump(schema, f, default_flow_style=False, sort_keys=False)
 
 
+def _get_explicit_schema_names(openapi_schema: dict[str, Any]) -> set[str]:
+    """Get all registered schema names and @json_schema_type decorated model names."""
+    from llama_stack.schema_utils import _registered_schemas
+
+    registered_schema_names = {info["name"] for info in _registered_schemas.values()}
+    json_schema_type_names = _get_all_json_schema_type_names()
+    return registered_schema_names | json_schema_type_names
+
+
+def _add_transitive_references(
+    referenced_schemas: set[str], all_schemas: dict[str, Any], initial_schemas: set[str] | None = None
+) -> set[str]:
+    """Add transitive references for given schemas."""
+    if initial_schemas:
+        referenced_schemas.update(initial_schemas)
+        additional_schemas = set()
+        for schema_name in initial_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+    else:
+        additional_schemas = set()
+        for schema_name in referenced_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+
+    while additional_schemas:
+        new_schemas = additional_schemas - referenced_schemas
+        if not new_schemas:
+            break
+        referenced_schemas.update(new_schemas)
+        additional_schemas = set()
+        for schema_name in new_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+
+    return referenced_schemas
+
+
+def _filter_schemas_by_references(
+    filtered_schema: dict[str, Any], filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]
+) -> dict[str, Any]:
+    """Filter schemas to only include ones referenced by filtered paths and explicit schemas."""
+    if "components" not in filtered_schema or "schemas" not in filtered_schema["components"]:
+        return filtered_schema
+
+    referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
+    all_schemas = openapi_schema.get("components", {}).get("schemas", {})
+    explicit_schema_names = _get_explicit_schema_names(openapi_schema)
+    referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas, explicit_schema_names)
+
+    filtered_schemas = {
+        name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas
+    }
+    filtered_schema["components"]["schemas"] = filtered_schemas
+
+    if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
+        if "components" not in filtered_schema:
+            filtered_schema["components"] = {}
+        filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
+
+    return filtered_schema
+
+
 def _filter_schema_by_version(
     openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True
 ) -> dict[str, Any]:
@@ -1369,80 +1432,15 @@ def _filter_schema_by_version(
     if "paths" not in filtered_schema:
         return filtered_schema
 
-    # Filter paths based on version prefix and deprecated status
     filtered_paths = {}
     for path, path_item in filtered_schema["paths"].items():
-        # Check if path has any deprecated operations
-        is_deprecated = _is_path_deprecated(path_item)
-
-        # Skip deprecated endpoints if exclude_deprecated is True
-        if exclude_deprecated and is_deprecated:
+        if exclude_deprecated and _is_path_deprecated(path_item):
             continue
-
-        if stable_only:
-            # Only include stable v1 paths, exclude v1alpha and v1beta
-            if _is_stable_path(path):
-                filtered_paths[path] = path_item
-        else:
-            # Only include experimental paths (v1alpha or v1beta), exclude v1
-            if _is_experimental_path(path):
-                filtered_paths[path] = path_item
+        if (stable_only and _is_stable_path(path)) or (not stable_only and _is_experimental_path(path)):
+            filtered_paths[path] = path_item
 
     filtered_schema["paths"] = filtered_paths
-
-    # Filter schemas/components to only include ones referenced by filtered paths
-    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
-        # Find all schemas that are actually referenced by the filtered paths
-        # Use the original schema to find all references, not the filtered one
-        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
-
-        # Also include all registered schemas and @json_schema_type decorated models
-        # (they should always be included) and all schemas they reference (transitive references)
-        from llama_stack.schema_utils import _registered_schemas
-
-        # Use the original schema to find registered schema definitions
-        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
-        registered_schema_names = set()
-        for registration_info in _registered_schemas.values():
-            registered_schema_names.add(registration_info["name"])
-
-        # Also include all @json_schema_type decorated models
-        json_schema_type_names = _get_all_json_schema_type_names()
-        all_explicit_schema_names = registered_schema_names | json_schema_type_names
-
-        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
-        additional_schemas = set()
-        for schema_name in all_explicit_schema_names:
-            referenced_schemas.add(schema_name)
-            if schema_name in all_schemas:
-                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Keep adding transitive references until no new ones are found
-        while additional_schemas:
-            new_schemas = additional_schemas - referenced_schemas
-            if not new_schemas:
-                break
-            referenced_schemas.update(new_schemas)
-            additional_schemas = set()
-            for schema_name in new_schemas:
-                if schema_name in all_schemas:
-                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Only keep schemas that are referenced by the filtered paths or are registered/@json_schema_type
-        filtered_schemas = {}
-        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
-            if schema_name in referenced_schemas:
-                filtered_schemas[schema_name] = schema_def
-        filtered_schema["components"]["schemas"] = filtered_schemas
-
-        # Preserve $defs section if it exists
-        if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
-            if "components" not in filtered_schema:
-                filtered_schema["components"] = {}
-            filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
-
-    return filtered_schema
+    return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)
 
 
 def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]:
@@ -1615,50 +1613,7 @@ def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
     filtered_schema["paths"] = filtered_paths
 
-    # Filter schemas/components to only include ones referenced by filtered paths
-    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
-        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
-
-        # Also include all registered schemas and @json_schema_type decorated models
-        # (they should always be included) and all schemas they reference (transitive references)
-        from llama_stack.schema_utils import _registered_schemas
-
-        # Use the original schema to find registered schema definitions
-        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
-        registered_schema_names = set()
-        for registration_info in _registered_schemas.values():
-            registered_schema_names.add(registration_info["name"])
-
-        # Also include all @json_schema_type decorated models
-        json_schema_type_names = _get_all_json_schema_type_names()
-        all_explicit_schema_names = registered_schema_names | json_schema_type_names
-
-        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
-        additional_schemas = set()
-        for schema_name in all_explicit_schema_names:
-            referenced_schemas.add(schema_name)
-            if schema_name in all_schemas:
-                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Keep adding transitive references until no new ones are found
-        while additional_schemas:
-            new_schemas = additional_schemas - referenced_schemas
-            if not new_schemas:
-                break
-            referenced_schemas.update(new_schemas)
-            additional_schemas = set()
-            for schema_name in new_schemas:
-                if schema_name in all_schemas:
-                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        filtered_schemas = {}
-        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
-            if schema_name in referenced_schemas:
-                filtered_schemas[schema_name] = schema_def
-        filtered_schema["components"]["schemas"] = filtered_schemas
-
-    return filtered_schema
+    return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)
 
 
 def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
@@ -1727,7 +1682,6 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
     deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema))
     combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema))
 
-    # Base description for all specs
     base_description = (
         "This is the specification of the Llama Stack that provides\n"
         " a set of endpoints and their corresponding interfaces that are\n"
@@ -1735,69 +1689,52 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
         " best leverage Llama Models."
     )
 
-    # Update info section for stable schema
-    if "info" not in stable_schema:
-        stable_schema["info"] = {}
-    stable_schema["info"]["title"] = "Llama Stack Specification"
-    stable_schema["info"]["version"] = "v1"
-    stable_schema["info"]["description"] = (
-        base_description + "\n\n **✅ STABLE**: Production-ready APIs with backward compatibility guarantees."
-    )
-
-    # Update info section for experimental schema
-    if "info" not in experimental_schema:
-        experimental_schema["info"] = {}
-    experimental_schema["info"]["title"] = "Llama Stack Specification - Experimental APIs"
-    experimental_schema["info"]["version"] = "v1"
-    experimental_schema["info"]["description"] = (
-        base_description + "\n\n **🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n"
-        " becoming stable."
-    )
-
-    # Update info section for deprecated schema
-    if "info" not in deprecated_schema:
-        deprecated_schema["info"] = {}
-    deprecated_schema["info"]["title"] = "Llama Stack Specification - Deprecated APIs"
-    deprecated_schema["info"]["version"] = "v1"
-    deprecated_schema["info"]["description"] = (
-        base_description + "\n\n **⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n"
-        " migration reference only."
-    )
-
-    # Update info section for combined schema
-    if "info" not in combined_schema:
-        combined_schema["info"] = {}
-    combined_schema["info"]["title"] = "Llama Stack Specification - Stable & Experimental APIs"
-    combined_schema["info"]["version"] = "v1"
-    combined_schema["info"]["description"] = (
-        base_description + "\n\n\n"
-        " **🔗 COMBINED**: This specification includes both stable production-ready APIs\n"
-        " and experimental pre-release APIs. Use stable APIs for production deployments\n"
-        " and experimental APIs for testing new features."
-    )
-
-    # Fix schema issues (like exclusiveMinimum -> minimum) for each spec
-    stable_schema = _fix_schema_issues(stable_schema)
-    experimental_schema = _fix_schema_issues(experimental_schema)
-    deprecated_schema = _fix_schema_issues(deprecated_schema)
-    combined_schema = _fix_schema_issues(combined_schema)
-
-    # Validate the schemas
+    schema_configs = [
+        (
+            stable_schema,
+            "Llama Stack Specification",
+            "**✅ STABLE**: Production-ready APIs with backward compatibility guarantees.",
+        ),
+        (
+            experimental_schema,
+            "Llama Stack Specification - Experimental APIs",
+            "**🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n becoming stable.",
+        ),
+        (
+            deprecated_schema,
+            "Llama Stack Specification - Deprecated APIs",
+            "**⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n migration reference only.",
+        ),
+        (
+            combined_schema,
+            "Llama Stack Specification - Stable & Experimental APIs",
+            "**🔗 COMBINED**: This specification includes both stable production-ready APIs\n and experimental pre-release APIs. Use stable APIs for production deployments\n and experimental APIs for testing new features.",
+        ),
+    ]
+
+    for schema, title, description_suffix in schema_configs:
+        if "info" not in schema:
+            schema["info"] = {}
+        schema["info"].update(
+            {
+                "title": title,
+                "version": "v1",
+                "description": f"{base_description}\n\n {description_suffix}",
+            }
+        )
+
+    schemas_to_validate = [
+        (stable_schema, "Stable schema"),
+        (experimental_schema, "Experimental schema"),
+        (deprecated_schema, "Deprecated schema"),
+        (combined_schema, "Combined (stainless) schema"),
+    ]
+
+    for schema, _ in schemas_to_validate:
+        _fix_schema_issues(schema)
+
     print("\n🔍 Validating generated schemas...")
-    stable_valid = validate_openapi_schema(stable_schema, "Stable schema")
-    experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema")
-    deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema")
-    combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema")
-
-    failed_schemas = []
-    if not stable_valid:
-        failed_schemas.append("Stable schema")
-    if not experimental_valid:
-        failed_schemas.append("Experimental schema")
-    if not deprecated_valid:
-        failed_schemas.append("Deprecated schema")
-    if not combined_valid:
-        failed_schemas.append("Combined (stainless) schema")
-
+    failed_schemas = [name for schema, name in schemas_to_validate if not validate_openapi_schema(schema, name)]
     if failed_schemas:
         raise ValueError(f"Invalid schemas: {', '.join(failed_schemas)}")
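
For illustration only, here is a minimal, self-contained Python sketch of the fixed-point expansion that the new _add_transitive_references helper performs. The find_refs and transitive_closure names and the toy schema map are hypothetical stand-ins; the real helper delegates reference discovery to _find_schema_refs_in_object in the generator script.

# Illustrative sketch, not part of the commit: the same fixed-point idea used by
# _add_transitive_references, applied to a toy schema map. All names below are
# hypothetical stand-ins for the generator's real helpers.
from typing import Any


def find_refs(obj: Any) -> set[str]:
    # Collect local "#/components/schemas/<Name>" references from a schema fragment.
    refs: set[str] = set()
    if isinstance(obj, dict):
        ref = obj.get("$ref")
        if isinstance(ref, str) and ref.startswith("#/components/schemas/"):
            refs.add(ref.rsplit("/", 1)[-1])
        for value in obj.values():
            refs |= find_refs(value)
    elif isinstance(obj, list):
        for item in obj:
            refs |= find_refs(item)
    return refs


def transitive_closure(seeds: set[str], all_schemas: dict[str, Any]) -> set[str]:
    # Keep expanding until no included schema references a schema outside the set.
    included = set(seeds)
    frontier = set(seeds)
    while frontier:
        discovered: set[str] = set()
        for name in frontier:
            if name in all_schemas:
                discovered |= find_refs(all_schemas[name])
        frontier = discovered - included
        included |= frontier
    return included


toy_schemas = {
    "Order": {"properties": {"item": {"$ref": "#/components/schemas/Item"}}},
    "Item": {"properties": {"price": {"$ref": "#/components/schemas/Price"}}},
    "Price": {"type": "number"},
    "Unused": {"type": "string"},
}
print(transitive_closure({"Order"}, toy_schemas))  # yields {'Order', 'Item', 'Price'}; 'Unused' is dropped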