llama-stack-mirror/scripts/distro_codegen.py
Ashwin Bharambe 9f87d67849 fixes
2025-11-14 15:29:31 -08:00

161 lines
5.1 KiB
Python
Executable file

#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
#
# CI can direct the generated artifacts into an alternate checkout by passing
# --repo-root, allowing the trusted copy of this script to run from a separate
# worktree.
import argparse
import concurrent.futures
import importlib
import subprocess
import sys
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from rich.progress import Progress, SpinnerColumn, TextColumn
_DEFAULT_REPO_ROOT = Path(__file__).parent.parent
REPO_ROOT = _DEFAULT_REPO_ROOT
def set_repo_root(repo_root: Path) -> None:
"""Update the global repository root used by helper functions."""
global REPO_ROOT
REPO_ROOT = repo_root
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Generate distribution docs and YAML artifacts."
)
parser.add_argument(
"--repo-root",
type=Path,
default=_DEFAULT_REPO_ROOT,
help="Repository root where generated artifacts should be written.",
)
return parser.parse_args()
class ChangedPathTracker:
"""Track a list of paths we may have changed."""
def __init__(self):
self._changed_paths = []
def add_paths(self, *paths):
for path in paths:
path = str(path)
if path not in self._changed_paths:
self._changed_paths.append(path)
def changed_paths(self):
return self._changed_paths
def find_distro_dirs(distro_dir: Path) -> Iterable[Path]:
"""Find immediate subdirectories in the distributions folder."""
if not distro_dir.exists():
raise FileNotFoundError(f"Distributions directory not found: {distro_dir}")
return sorted(d for d in distro_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
def process_distro(distro_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
"""Process a single distribution directory."""
progress.print(f"Processing {distro_dir.name}")
try:
# Import the module directly
module_name = f"llama_stack.distributions.{distro_dir.name}"
module = importlib.import_module(module_name)
# Get and save the distribution template
if template_func := getattr(module, "get_distribution_template", None):
distro = template_func()
yaml_output_dir = REPO_ROOT / "src" / "llama_stack" / "distributions" / distro.name
doc_output_dir = REPO_ROOT / "docs/docs/distributions" / f"{distro.distro_type}_distro"
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
distro.save_distribution(
yaml_output_dir=yaml_output_dir,
doc_output_dir=doc_output_dir,
)
else:
progress.print(f"[yellow]Warning: {distro_dir.name} has no get_distribution_template function")
except Exception as e:
progress.print(f"[red]Error processing {distro_dir.name}: {str(e)}")
raise e
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
"""Check if there are any uncommitted changes."""
has_changes = False
for path in change_tracker.changed_paths():
result = subprocess.run(
["git", "diff", "--exit-code", path],
cwd=REPO_ROOT,
capture_output=True,
)
if result.returncode != 0:
print(f"Change detected in '{path}'.", file=sys.stderr)
has_changes = True
return has_changes
def pre_import_distros(distro_dirs: list[Path]) -> None:
# Pre-import all distro modules to avoid deadlocks.
for distro_dir in distro_dirs:
module_name = f"llama_stack.distributions.{distro_dir.name}"
importlib.import_module(module_name)
def main():
args = parse_args()
repo_root = args.repo_root
if not repo_root.is_absolute():
repo_root = (Path.cwd() / repo_root).resolve()
set_repo_root(repo_root)
distros_dir = REPO_ROOT / "src" / "llama_stack" / "distributions"
change_tracker = ChangedPathTracker()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
) as progress:
distro_dirs = list(find_distro_dirs(distros_dir))
task = progress.add_task("Processing distribution templates...", total=len(distro_dirs))
pre_import_distros(distro_dirs)
# Create a partial function with the progress bar
process_func = partial(process_distro, progress=progress, change_tracker=change_tracker)
# Process distributions in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit all tasks and wait for completion
list(executor.map(process_func, distro_dirs))
progress.update(task, advance=len(distro_dirs))
if check_for_changes(change_tracker):
print(
"Distribution changes detected. Please commit the changes.",
file=sys.stderr,
)
sys.exit(1)
sys.exit(0)
if __name__ == "__main__":
main()