mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 19:59:47 +00:00
Write a script to perform the codegen
This commit is contained in:
parent
f38e76ee98
commit
0218e68849
9 changed files with 223 additions and 142 deletions
78
llama_stack/scripts/save_distributions.py
Normal file
78
llama_stack/scripts/save_distributions.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import concurrent.futures
|
||||
import importlib
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
|
||||
"""Find immediate subdirectories in the templates folder."""
|
||||
if not templates_dir.exists():
|
||||
raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
|
||||
|
||||
return (d for d in templates_dir.iterdir() if d.is_dir())
|
||||
|
||||
|
||||
def process_template(template_dir: Path, progress) -> None:
|
||||
"""Process a single template directory."""
|
||||
progress.print(f"Processing {template_dir.name}")
|
||||
|
||||
try:
|
||||
# Import the module directly
|
||||
module_name = f"llama_stack.templates.{template_dir.name}"
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Get and save the distribution template
|
||||
if template_func := getattr(module, "get_distribution_template", None):
|
||||
template = template_func()
|
||||
|
||||
template.save_distribution(
|
||||
yaml_output_dir=REPO_ROOT / "distributions" / template.name,
|
||||
doc_output_dir=REPO_ROOT
|
||||
/ "docs/source/getting_started/distributions"
|
||||
/ f"{template.distro_type}_distro",
|
||||
)
|
||||
else:
|
||||
progress.print(
|
||||
f"[yellow]Warning: {template_dir.name} has no get_distribution_template function"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
progress.print(f"[red]Error processing {template_dir.name}: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
) as progress:
|
||||
template_dirs = list(find_template_dirs(templates_dir))
|
||||
task = progress.add_task(
|
||||
"Processing distribution templates...", total=len(template_dirs)
|
||||
)
|
||||
|
||||
# Create a partial function with the progress bar
|
||||
process_func = partial(process_template, progress=progress)
|
||||
|
||||
# Process templates in parallel
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# Submit all tasks and wait for completion
|
||||
list(executor.map(process_func, template_dirs))
|
||||
progress.update(task, advance=len(template_dirs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue