mirror of https://github.com/meta-llama/llama-stack.git
chore: refactor code to reduce generator script length
Signed-off-by: Sébastien Han <seb@redhat.com>
parent de4ed29310
commit 3d33291f23
3 changed files with 109 additions and 174 deletions
@@ -1350,6 +1350,69 @@ def _write_yaml_file(file_path: Path, schema: dict[str, Any]) -> None:
         yaml.dump(schema, f, default_flow_style=False, sort_keys=False)


+def _get_explicit_schema_names(openapi_schema: dict[str, Any]) -> set[str]:
+    """Get all registered schema names and @json_schema_type decorated model names."""
+    from llama_stack.schema_utils import _registered_schemas
+
+    registered_schema_names = {info["name"] for info in _registered_schemas.values()}
+    json_schema_type_names = _get_all_json_schema_type_names()
+    return registered_schema_names | json_schema_type_names
+
+
+def _add_transitive_references(
+    referenced_schemas: set[str], all_schemas: dict[str, Any], initial_schemas: set[str] | None = None
+) -> set[str]:
+    """Add transitive references for given schemas."""
+    if initial_schemas:
+        referenced_schemas.update(initial_schemas)
+        additional_schemas = set()
+        for schema_name in initial_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+    else:
+        additional_schemas = set()
+        for schema_name in referenced_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+
+    while additional_schemas:
+        new_schemas = additional_schemas - referenced_schemas
+        if not new_schemas:
+            break
+        referenced_schemas.update(new_schemas)
+        additional_schemas = set()
+        for schema_name in new_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+
+    return referenced_schemas
+
+
+def _filter_schemas_by_references(
+    filtered_schema: dict[str, Any], filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]
+) -> dict[str, Any]:
+    """Filter schemas to only include ones referenced by filtered paths and explicit schemas."""
+    if "components" not in filtered_schema or "schemas" not in filtered_schema["components"]:
+        return filtered_schema
+
+    referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
+    all_schemas = openapi_schema.get("components", {}).get("schemas", {})
+    explicit_schema_names = _get_explicit_schema_names(openapi_schema)
+    referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas, explicit_schema_names)
+
+    filtered_schemas = {
+        name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas
+    }
+    filtered_schema["components"]["schemas"] = filtered_schemas
+
+    if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
+        if "components" not in filtered_schema:
+            filtered_schema["components"] = {}
+        filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
+
+    return filtered_schema
+
+
 def _filter_schema_by_version(
     openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True
 ) -> dict[str, Any]:
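The new _add_transitive_references helper is a fixed-point walk: keep pulling in the $ref targets of newly discovered schemas until nothing new appears. Below is a minimal standalone sketch of that idea, not the generator's code; find_refs is a hypothetical stand-in for _find_schema_refs_in_object, and the toy schemas are invented for illustration.

# Sketch: expand schema references to a fixed point.
from typing import Any


def find_refs(obj: Any) -> set[str]:
    """Collect local '#/components/schemas/<Name>' references from a schema fragment."""
    refs: set[str] = set()
    if isinstance(obj, dict):
        ref = obj.get("$ref")
        if isinstance(ref, str) and ref.startswith("#/components/schemas/"):
            refs.add(ref.rsplit("/", 1)[-1])
        for value in obj.values():
            refs |= find_refs(value)
    elif isinstance(obj, list):
        for item in obj:
            refs |= find_refs(item)
    return refs


# Toy schema set: Order -> Item -> Price, plus one unreferenced schema.
all_schemas = {
    "Order": {"properties": {"items": {"items": {"$ref": "#/components/schemas/Item"}}}},
    "Item": {"properties": {"price": {"$ref": "#/components/schemas/Price"}}},
    "Price": {"type": "number"},
    "Unused": {"type": "string"},
}

referenced = {"Order"}
pending = find_refs(all_schemas["Order"])
while pending:
    new = pending - referenced
    if not new:
        break
    referenced |= new
    pending = set()
    for name in new:
        if name in all_schemas:
            pending |= find_refs(all_schemas[name])

print(sorted(referenced))  # ['Item', 'Order', 'Price'] -- 'Unused' is dropped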
@@ -1369,80 +1432,15 @@ def _filter_schema_by_version
     if "paths" not in filtered_schema:
         return filtered_schema

     # Filter paths based on version prefix and deprecated status
     filtered_paths = {}
     for path, path_item in filtered_schema["paths"].items():
-        # Check if path has any deprecated operations
-        is_deprecated = _is_path_deprecated(path_item)
-
-        # Skip deprecated endpoints if exclude_deprecated is True
-        if exclude_deprecated and is_deprecated:
+        if exclude_deprecated and _is_path_deprecated(path_item):
             continue

-        if stable_only:
-            # Only include stable v1 paths, exclude v1alpha and v1beta
-            if _is_stable_path(path):
-                filtered_paths[path] = path_item
-        else:
-            # Only include experimental paths (v1alpha or v1beta), exclude v1
-            if _is_experimental_path(path):
-                filtered_paths[path] = path_item
+        if (stable_only and _is_stable_path(path)) or (not stable_only and _is_experimental_path(path)):
+            filtered_paths[path] = path_item

     filtered_schema["paths"] = filtered_paths

-    # Filter schemas/components to only include ones referenced by filtered paths
-    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
-        # Find all schemas that are actually referenced by the filtered paths
-        # Use the original schema to find all references, not the filtered one
-        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
-
-        # Also include all registered schemas and @json_schema_type decorated models
-        # (they should always be included) and all schemas they reference (transitive references)
-        from llama_stack.schema_utils import _registered_schemas
-
-        # Use the original schema to find registered schema definitions
-        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
-        registered_schema_names = set()
-        for registration_info in _registered_schemas.values():
-            registered_schema_names.add(registration_info["name"])
-
-        # Also include all @json_schema_type decorated models
-        json_schema_type_names = _get_all_json_schema_type_names()
-        all_explicit_schema_names = registered_schema_names | json_schema_type_names
-
-        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
-        additional_schemas = set()
-        for schema_name in all_explicit_schema_names:
-            referenced_schemas.add(schema_name)
-            if schema_name in all_schemas:
-                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Keep adding transitive references until no new ones are found
-        while additional_schemas:
-            new_schemas = additional_schemas - referenced_schemas
-            if not new_schemas:
-                break
-            referenced_schemas.update(new_schemas)
-            additional_schemas = set()
-            for schema_name in new_schemas:
-                if schema_name in all_schemas:
-                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Only keep schemas that are referenced by the filtered paths or are registered/@json_schema_type
-        filtered_schemas = {}
-        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
-            if schema_name in referenced_schemas:
-                filtered_schemas[schema_name] = schema_def
-
-        filtered_schema["components"]["schemas"] = filtered_schemas
-
-    # Preserve $defs section if it exists
-    if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
-        if "components" not in filtered_schema:
-            filtered_schema["components"] = {}
-        filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
-
-    return filtered_schema
+    return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)


 def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]:
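The path filter now collapses the stable/experimental branches into a single predicate. Here is a small self-contained sketch of that selection logic; is_stable_path and is_experimental_path are hypothetical prefix-based stand-ins for the generator's _is_stable_path and _is_experimental_path, whose actual implementations may differ.

# Sketch: one predicate selects either stable or experimental paths.
def is_experimental_path(path: str) -> bool:
    return path.startswith("/v1alpha/") or path.startswith("/v1beta/")


def is_stable_path(path: str) -> bool:
    return path.startswith("/v1/") and not is_experimental_path(path)


paths = ["/v1/models", "/v1alpha/agents", "/v1beta/datasets"]
for stable_only in (True, False):
    kept = [
        p for p in paths
        if (stable_only and is_stable_path(p)) or (not stable_only and is_experimental_path(p))
    ]
    print(stable_only, kept)
# True  ['/v1/models']
# False ['/v1alpha/agents', '/v1beta/datasets']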
@@ -1615,50 +1613,7 @@ def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
     filtered_schema["paths"] = filtered_paths

-    # Filter schemas/components to only include ones referenced by filtered paths
-    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
-        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
-
-        # Also include all registered schemas and @json_schema_type decorated models
-        # (they should always be included) and all schemas they reference (transitive references)
-        from llama_stack.schema_utils import _registered_schemas
-
-        # Use the original schema to find registered schema definitions
-        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
-        registered_schema_names = set()
-        for registration_info in _registered_schemas.values():
-            registered_schema_names.add(registration_info["name"])
-
-        # Also include all @json_schema_type decorated models
-        json_schema_type_names = _get_all_json_schema_type_names()
-        all_explicit_schema_names = registered_schema_names | json_schema_type_names
-
-        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
-        additional_schemas = set()
-        for schema_name in all_explicit_schema_names:
-            referenced_schemas.add(schema_name)
-            if schema_name in all_schemas:
-                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Keep adding transitive references until no new ones are found
-        while additional_schemas:
-            new_schemas = additional_schemas - referenced_schemas
-            if not new_schemas:
-                break
-            referenced_schemas.update(new_schemas)
-            additional_schemas = set()
-            for schema_name in new_schemas:
-                if schema_name in all_schemas:
-                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        filtered_schemas = {}
-        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
-            if schema_name in referenced_schemas:
-                filtered_schemas[schema_name] = schema_def
-
-        filtered_schema["components"]["schemas"] = filtered_schemas
-
-    return filtered_schema
+    return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)


 def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
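Both _filter_schema_by_version and _filter_combined_schema now end in the same call to _filter_schemas_by_references. A simplified sketch of the effect of that shared step on a toy schema follows; the data and the precomputed referenced_schemas set are invented for illustration, whereas the real helper derives the set from the filtered paths plus registered and @json_schema_type schemas.

# Sketch: keep only referenced component schemas and carry $defs through.
referenced_schemas = {"Model", "ListModelsResponse"}

filtered_schema = {
    "paths": {"/v1/models": {"get": {"operationId": "list_models"}}},
    "components": {
        "schemas": {
            "Model": {"type": "object"},
            "ListModelsResponse": {"type": "object"},
            "InternalOnly": {"type": "object"},
        }
    },
}
original_schema = {"components": {"schemas": {}, "$defs": {"Shared": {"type": "string"}}}}

filtered_schema["components"]["schemas"] = {
    name: schema
    for name, schema in filtered_schema["components"]["schemas"].items()
    if name in referenced_schemas
}
if "$defs" in original_schema.get("components", {}):
    filtered_schema["components"]["$defs"] = original_schema["components"]["$defs"]

print(sorted(filtered_schema["components"]["schemas"]))  # ['ListModelsResponse', 'Model']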
@@ -1727,7 +1682,6 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
     deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema))
     combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema))

-    # Base description for all specs
     base_description = (
         "This is the specification of the Llama Stack that provides\n"
         " a set of endpoints and their corresponding interfaces that are\n"
@@ -1735,69 +1689,52 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
         " best leverage Llama Models."
     )

-    # Update info section for stable schema
-    if "info" not in stable_schema:
-        stable_schema["info"] = {}
-    stable_schema["info"]["title"] = "Llama Stack Specification"
-    stable_schema["info"]["version"] = "v1"
-    stable_schema["info"]["description"] = (
-        base_description + "\n\n **✅ STABLE**: Production-ready APIs with backward compatibility guarantees."
-    )
+    schema_configs = [
+        (
+            stable_schema,
+            "Llama Stack Specification",
+            "**✅ STABLE**: Production-ready APIs with backward compatibility guarantees.",
+        ),
+        (
+            experimental_schema,
+            "Llama Stack Specification - Experimental APIs",
+            "**🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n becoming stable.",
+        ),
+        (
+            deprecated_schema,
+            "Llama Stack Specification - Deprecated APIs",
+            "**⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n migration reference only.",
+        ),
+        (
+            combined_schema,
+            "Llama Stack Specification - Stable & Experimental APIs",
+            "**🔗 COMBINED**: This specification includes both stable production-ready APIs\n and experimental pre-release APIs. Use stable APIs for production deployments\n and experimental APIs for testing new features.",
+        ),
+    ]

-    # Update info section for experimental schema
-    if "info" not in experimental_schema:
-        experimental_schema["info"] = {}
-    experimental_schema["info"]["title"] = "Llama Stack Specification - Experimental APIs"
-    experimental_schema["info"]["version"] = "v1"
-    experimental_schema["info"]["description"] = (
-        base_description + "\n\n **🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n"
-        " becoming stable."
-    )
+    for schema, title, description_suffix in schema_configs:
+        if "info" not in schema:
+            schema["info"] = {}
+        schema["info"].update(
+            {
+                "title": title,
+                "version": "v1",
+                "description": f"{base_description}\n\n {description_suffix}",
+            }
+        )

-    # Update info section for deprecated schema
-    if "info" not in deprecated_schema:
-        deprecated_schema["info"] = {}
-    deprecated_schema["info"]["title"] = "Llama Stack Specification - Deprecated APIs"
-    deprecated_schema["info"]["version"] = "v1"
-    deprecated_schema["info"]["description"] = (
-        base_description + "\n\n **⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n"
-        " migration reference only."
-    )
+    schemas_to_validate = [
+        (stable_schema, "Stable schema"),
+        (experimental_schema, "Experimental schema"),
+        (deprecated_schema, "Deprecated schema"),
+        (combined_schema, "Combined (stainless) schema"),
+    ]

-    # Update info section for combined schema
-    if "info" not in combined_schema:
-        combined_schema["info"] = {}
-    combined_schema["info"]["title"] = "Llama Stack Specification - Stable & Experimental APIs"
-    combined_schema["info"]["version"] = "v1"
-    combined_schema["info"]["description"] = (
-        base_description + "\n\n\n"
-        " **🔗 COMBINED**: This specification includes both stable production-ready APIs\n"
-        " and experimental pre-release APIs. Use stable APIs for production deployments\n"
-        " and experimental APIs for testing new features."
-    )
+    for schema, _ in schemas_to_validate:
+        _fix_schema_issues(schema)

-    # Fix schema issues (like exclusiveMinimum -> minimum) for each spec
-    stable_schema = _fix_schema_issues(stable_schema)
-    experimental_schema = _fix_schema_issues(experimental_schema)
-    deprecated_schema = _fix_schema_issues(deprecated_schema)
-    combined_schema = _fix_schema_issues(combined_schema)

     # Validate the schemas
     print("\n🔍 Validating generated schemas...")
-    stable_valid = validate_openapi_schema(stable_schema, "Stable schema")
-    experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema")
-    deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema")
-    combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema")
-
-    failed_schemas = []
-    if not stable_valid:
-        failed_schemas.append("Stable schema")
-    if not experimental_valid:
-        failed_schemas.append("Experimental schema")
-    if not deprecated_valid:
-        failed_schemas.append("Deprecated schema")
-    if not combined_valid:
-        failed_schemas.append("Combined (stainless) schema")
+    failed_schemas = [name for schema, name in schemas_to_validate if not validate_openapi_schema(schema, name)]
     if failed_schemas:
         raise ValueError(f"Invalid schemas: {', '.join(failed_schemas)}")
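The info-section updates, the _fix_schema_issues pass, and validation now run off small tables instead of four near-identical blocks. Below is a compact, self-contained sketch of that table-driven pattern with toy data; validate is a hypothetical stand-in for validate_openapi_schema, and the descriptions are invented.

# Sketch: update several schema dicts from one config table, then validate them.
base_description = "Llama Stack API."

stable_schema: dict = {}
experimental_schema: dict = {}

schema_configs = [
    (stable_schema, "Stable", "Production-ready."),
    (experimental_schema, "Experimental", "May change."),
]

for schema, title, suffix in schema_configs:
    schema.setdefault("info", {}).update(
        {"title": title, "version": "v1", "description": f"{base_description}\n\n {suffix}"}
    )


def validate(schema: dict, name: str) -> bool:
    # Stand-in check: a real validator would inspect the whole OpenAPI document.
    return "info" in schema


schemas_to_validate = [(stable_schema, "Stable schema"), (experimental_schema, "Experimental schema")]
failed = [name for schema, name in schemas_to_validate if not validate(schema, name)]
if failed:
    raise ValueError(f"Invalid schemas: {', '.join(failed)}")
print("all schemas valid")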