From 3d33291f2359fa27ba4331eb863832395f2768f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Wed, 12 Nov 2025 12:12:23 +0100
Subject: [PATCH] chore: refactor code to reduce generator script length
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Sébastien Han
---
 client-sdks/stainless/openapi.yml           |   1 -
 docs/static/stainless-llama-stack-spec.yaml |   1 -
 scripts/fastapi_generator.py                | 281 ++++++++------------
 3 files changed, 109 insertions(+), 174 deletions(-)

diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml
index db194dc2e..4ffe1de17 100644
--- a/client-sdks/stainless/openapi.yml
+++ b/client-sdks/stainless/openapi.yml
@@ -7,7 +7,6 @@ info:
     tailored to best leverage Llama Models.
 
-
     **🔗 COMBINED**: This specification includes both stable production-ready APIs
     and experimental pre-release APIs. Use stable APIs for production deployments
     and experimental APIs for testing new features.
 
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index ed9f1fe78..f45ea0e82 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -7,7 +7,6 @@ info:
     tailored to best leverage Llama Models.
 
-
     **🔗 COMBINED**: This specification includes both stable production-ready APIs
     and experimental pre-release APIs. Use stable APIs for production deployments
    and experimental APIs for testing new features.
 
diff --git a/scripts/fastapi_generator.py b/scripts/fastapi_generator.py
index d3b3e590f..4bc38b9f1 100755
--- a/scripts/fastapi_generator.py
+++ b/scripts/fastapi_generator.py
@@ -1350,6 +1350,69 @@ def _write_yaml_file(file_path: Path, schema: dict[str, Any]) -> None:
         yaml.dump(schema, f, default_flow_style=False, sort_keys=False)
 
 
+def _get_explicit_schema_names(openapi_schema: dict[str, Any]) -> set[str]:
+    """Get all registered schema names and @json_schema_type decorated model names."""
+    from llama_stack.schema_utils import _registered_schemas
+
+    registered_schema_names = {info["name"] for info in _registered_schemas.values()}
+    json_schema_type_names = _get_all_json_schema_type_names()
+    return registered_schema_names | json_schema_type_names
+
+
+def _add_transitive_references(
+    referenced_schemas: set[str], all_schemas: dict[str, Any], initial_schemas: set[str] | None = None
+) -> set[str]:
+    """Add transitive references for given schemas."""
+    if initial_schemas:
+        referenced_schemas.update(initial_schemas)
+        additional_schemas = set()
+        for schema_name in initial_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+    else:
+        additional_schemas = set()
+        for schema_name in referenced_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+
+    while additional_schemas:
+        new_schemas = additional_schemas - referenced_schemas
+        if not new_schemas:
+            break
+        referenced_schemas.update(new_schemas)
+        additional_schemas = set()
+        for schema_name in new_schemas:
+            if schema_name in all_schemas:
+                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
+
+    return referenced_schemas
+
+
+def _filter_schemas_by_references(
+    filtered_schema: dict[str, Any], filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]
+) -> dict[str, Any]:
+    """Filter schemas to only include ones referenced by filtered paths and explicit schemas."""
+    if "components" not in filtered_schema or "schemas" not in filtered_schema["components"]:
+        return filtered_schema
+
+    referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
+    all_schemas = openapi_schema.get("components", {}).get("schemas", {})
+    explicit_schema_names = _get_explicit_schema_names(openapi_schema)
+    referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas, explicit_schema_names)
+
+    filtered_schemas = {
+        name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas
+    }
+    filtered_schema["components"]["schemas"] = filtered_schemas
+
+    if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
+        if "components" not in filtered_schema:
+            filtered_schema["components"] = {}
+        filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
+
+    return filtered_schema
+
+
 def _filter_schema_by_version(
     openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True
 ) -> dict[str, Any]:
@@ -1369,80 +1432,15 @@ def _filter_schema_by_version(
     if "paths" not in filtered_schema:
         return filtered_schema
 
-    # Filter paths based on version prefix and deprecated status
     filtered_paths = {}
     for path, path_item in filtered_schema["paths"].items():
-        # Check if path has any deprecated operations
-        is_deprecated = _is_path_deprecated(path_item)
-
-        # Skip deprecated endpoints if exclude_deprecated is True
-        if exclude_deprecated and is_deprecated:
+        if exclude_deprecated and _is_path_deprecated(path_item):
             continue
-
-        if stable_only:
-            # Only include stable v1 paths, exclude v1alpha and v1beta
-            if _is_stable_path(path):
-                filtered_paths[path] = path_item
-        else:
-            # Only include experimental paths (v1alpha or v1beta), exclude v1
-            if _is_experimental_path(path):
-                filtered_paths[path] = path_item
+        if (stable_only and _is_stable_path(path)) or (not stable_only and _is_experimental_path(path)):
+            filtered_paths[path] = path_item
 
     filtered_schema["paths"] = filtered_paths
-
-    # Filter schemas/components to only include ones referenced by filtered paths
-    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
-        # Find all schemas that are actually referenced by the filtered paths
-        # Use the original schema to find all references, not the filtered one
-        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
-
-        # Also include all registered schemas and @json_schema_type decorated models
-        # (they should always be included) and all schemas they reference (transitive references)
-        from llama_stack.schema_utils import _registered_schemas
-
-        # Use the original schema to find registered schema definitions
-        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
-        registered_schema_names = set()
-        for registration_info in _registered_schemas.values():
-            registered_schema_names.add(registration_info["name"])
-
-        # Also include all @json_schema_type decorated models
-        json_schema_type_names = _get_all_json_schema_type_names()
-        all_explicit_schema_names = registered_schema_names | json_schema_type_names
-
-        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
-        additional_schemas = set()
-        for schema_name in all_explicit_schema_names:
-            referenced_schemas.add(schema_name)
-            if schema_name in all_schemas:
-                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Keep adding transitive references until no new ones are found
-        while additional_schemas:
-            new_schemas = additional_schemas - referenced_schemas
-            if not new_schemas:
-                break
-            referenced_schemas.update(new_schemas)
-            additional_schemas = set()
-            for schema_name in new_schemas:
-                if schema_name in all_schemas:
-                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Only keep schemas that are referenced by the filtered paths or are registered/@json_schema_type
-        filtered_schemas = {}
-        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
-            if schema_name in referenced_schemas:
-                filtered_schemas[schema_name] = schema_def
-
-        filtered_schema["components"]["schemas"] = filtered_schemas
-
-        # Preserve $defs section if it exists
-        if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
-            if "components" not in filtered_schema:
-                filtered_schema["components"] = {}
-            filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
-
-    return filtered_schema
+    return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)
 
 
 def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]:
@@ -1615,50 +1613,7 @@ def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
 
     filtered_schema["paths"] = filtered_paths
 
-    # Filter schemas/components to only include ones referenced by filtered paths
-    if "components" in filtered_schema and "schemas" in filtered_schema["components"]:
-        referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
-
-        # Also include all registered schemas and @json_schema_type decorated models
-        # (they should always be included) and all schemas they reference (transitive references)
-        from llama_stack.schema_utils import _registered_schemas
-
-        # Use the original schema to find registered schema definitions
-        all_schemas = openapi_schema.get("components", {}).get("schemas", {})
-        registered_schema_names = set()
-        for registration_info in _registered_schemas.values():
-            registered_schema_names.add(registration_info["name"])
-
-        # Also include all @json_schema_type decorated models
-        json_schema_type_names = _get_all_json_schema_type_names()
-        all_explicit_schema_names = registered_schema_names | json_schema_type_names
-
-        # Find all schemas referenced by registered schemas and @json_schema_type models (transitive)
-        additional_schemas = set()
-        for schema_name in all_explicit_schema_names:
-            referenced_schemas.add(schema_name)
-            if schema_name in all_schemas:
-                additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        # Keep adding transitive references until no new ones are found
-        while additional_schemas:
-            new_schemas = additional_schemas - referenced_schemas
-            if not new_schemas:
-                break
-            referenced_schemas.update(new_schemas)
-            additional_schemas = set()
-            for schema_name in new_schemas:
-                if schema_name in all_schemas:
-                    additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
-
-        filtered_schemas = {}
-        for schema_name, schema_def in filtered_schema["components"]["schemas"].items():
-            if schema_name in referenced_schemas:
-                filtered_schemas[schema_name] = schema_def
-
-        filtered_schema["components"]["schemas"] = filtered_schemas
-
-    return filtered_schema
+    return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)
 
 
 def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
@@ -1727,7 +1682,6 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
     deprecated_schema = _filter_deprecated_schema(copy.deepcopy(openapi_schema))
     combined_schema = _filter_combined_schema(copy.deepcopy(openapi_schema))
 
-    # Base description for all specs
     base_description = (
         "This is the specification of the Llama Stack that provides\n"
         " a set of endpoints and their corresponding interfaces that are\n"
@@ -1735,69 +1689,52 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
         " best leverage Llama Models."
     )
 
-    # Update info section for stable schema
-    if "info" not in stable_schema:
-        stable_schema["info"] = {}
-    stable_schema["info"]["title"] = "Llama Stack Specification"
-    stable_schema["info"]["version"] = "v1"
-    stable_schema["info"]["description"] = (
-        base_description + "\n\n **✅ STABLE**: Production-ready APIs with backward compatibility guarantees."
-    )
+    schema_configs = [
+        (
+            stable_schema,
+            "Llama Stack Specification",
+            "**✅ STABLE**: Production-ready APIs with backward compatibility guarantees.",
+        ),
+        (
+            experimental_schema,
+            "Llama Stack Specification - Experimental APIs",
+            "**🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n becoming stable.",
+        ),
+        (
+            deprecated_schema,
+            "Llama Stack Specification - Deprecated APIs",
+            "**⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n migration reference only.",
+        ),
+        (
+            combined_schema,
+            "Llama Stack Specification - Stable & Experimental APIs",
+            "**🔗 COMBINED**: This specification includes both stable production-ready APIs\n and experimental pre-release APIs. Use stable APIs for production deployments\n and experimental APIs for testing new features.",
+        ),
+    ]
 
-    # Update info section for experimental schema
-    if "info" not in experimental_schema:
-        experimental_schema["info"] = {}
-    experimental_schema["info"]["title"] = "Llama Stack Specification - Experimental APIs"
-    experimental_schema["info"]["version"] = "v1"
-    experimental_schema["info"]["description"] = (
-        base_description + "\n\n **🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n"
-        " becoming stable."
-    )
+    for schema, title, description_suffix in schema_configs:
+        if "info" not in schema:
+            schema["info"] = {}
+        schema["info"].update(
+            {
+                "title": title,
+                "version": "v1",
+                "description": f"{base_description}\n\n {description_suffix}",
+            }
+        )
 
-    # Update info section for deprecated schema
-    if "info" not in deprecated_schema:
-        deprecated_schema["info"] = {}
-    deprecated_schema["info"]["title"] = "Llama Stack Specification - Deprecated APIs"
-    deprecated_schema["info"]["version"] = "v1"
-    deprecated_schema["info"]["description"] = (
-        base_description + "\n\n **⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n"
-        " migration reference only."
-    )
+    schemas_to_validate = [
+        (stable_schema, "Stable schema"),
+        (experimental_schema, "Experimental schema"),
+        (deprecated_schema, "Deprecated schema"),
+        (combined_schema, "Combined (stainless) schema"),
+    ]
 
-    # Update info section for combined schema
-    if "info" not in combined_schema:
-        combined_schema["info"] = {}
-    combined_schema["info"]["title"] = "Llama Stack Specification - Stable & Experimental APIs"
-    combined_schema["info"]["version"] = "v1"
-    combined_schema["info"]["description"] = (
-        base_description + "\n\n\n"
-        " **🔗 COMBINED**: This specification includes both stable production-ready APIs\n"
-        " and experimental pre-release APIs. Use stable APIs for production deployments\n"
-        " and experimental APIs for testing new features."
-    )
+    for schema, _ in schemas_to_validate:
+        _fix_schema_issues(schema)
 
-    # Fix schema issues (like exclusiveMinimum -> minimum) for each spec
-    stable_schema = _fix_schema_issues(stable_schema)
-    experimental_schema = _fix_schema_issues(experimental_schema)
-    deprecated_schema = _fix_schema_issues(deprecated_schema)
-    combined_schema = _fix_schema_issues(combined_schema)
-
-    # Validate the schemas
     print("\n🔍 Validating generated schemas...")
-    stable_valid = validate_openapi_schema(stable_schema, "Stable schema")
-    experimental_valid = validate_openapi_schema(experimental_schema, "Experimental schema")
-    deprecated_valid = validate_openapi_schema(deprecated_schema, "Deprecated schema")
-    combined_valid = validate_openapi_schema(combined_schema, "Combined (stainless) schema")
-
-    failed_schemas = []
-    if not stable_valid:
-        failed_schemas.append("Stable schema")
-    if not experimental_valid:
-        failed_schemas.append("Experimental schema")
-    if not deprecated_valid:
-        failed_schemas.append("Deprecated schema")
-    if not combined_valid:
-        failed_schemas.append("Combined (stainless) schema")
+    failed_schemas = [name for schema, name in schemas_to_validate if not validate_openapi_schema(schema, name)]
 
     if failed_schemas:
         raise ValueError(f"Invalid schemas: {', '.join(failed_schemas)}")