generator fix

This commit is contained in:
Alexey Rybak 2025-09-22 17:16:53 -07:00
parent 07b0ee0c2f
commit 584f3592ce
4 changed files with 294 additions and 40 deletions

View file

@ -39,14 +39,10 @@ def find_distro_dirs(distro_dir: Path) -> Iterable[Path]:
if not distro_dir.exists():
raise FileNotFoundError(f"Distributions directory not found: {distro_dir}")
return sorted(
d for d in distro_dir.iterdir() if d.is_dir() and d.name != "__pycache__"
)
return sorted(d for d in distro_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
def process_distro(
distro_dir: Path, progress, change_tracker: ChangedPathTracker
) -> None:
def process_distro(distro_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
"""Process a single distribution directory."""
progress.print(f"Processing {distro_dir.name}")
@ -60,18 +56,14 @@ def process_distro(
distro = template_func()
yaml_output_dir = REPO_ROOT / "llama_stack" / "distributions" / distro.name
doc_output_dir = (
REPO_ROOT / "docs/docs/distributions" / f"{distro.distro_type}_distro"
)
doc_output_dir = REPO_ROOT / "docs/docs/distributions" / f"{distro.distro_type}_distro"
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
distro.save_distribution(
yaml_output_dir=yaml_output_dir,
doc_output_dir=doc_output_dir,
)
else:
progress.print(
f"[yellow]Warning: {distro_dir.name} has no get_distribution_template function"
)
progress.print(f"[yellow]Warning: {distro_dir.name} has no get_distribution_template function")
except Exception as e:
progress.print(f"[red]Error processing {distro_dir.name}: {str(e)}")
@ -109,16 +101,12 @@ def main():
TextColumn("[progress.description]{task.description}"),
) as progress:
distro_dirs = list(find_distro_dirs(distros_dir))
task = progress.add_task(
"Processing distribution templates...", total=len(distro_dirs)
)
task = progress.add_task("Processing distribution templates...", total=len(distro_dirs))
pre_import_distros(distro_dirs)
# Create a partial function with the progress bar
process_func = partial(
process_distro, progress=progress, change_tracker=change_tracker
)
process_func = partial(process_distro, progress=progress, change_tracker=change_tracker)
# Process distributions in parallel
with concurrent.futures.ThreadPoolExecutor() as executor: