mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: Get distro_codegen.py working with default deps and enabled in pre-commit hooks (#1123)
# What does this PR do? Before this change, `distro_codegen.py` would only work if the user manually installed multiple provider-specific dependencies (see #1122). Now, users can run `distro_codegen.py` without any provider-specific dependencies because we avoid importing the entire provider implementations just to get the config needed to build the provider template. Concretely, this mostly means moving the MODEL_ALIASES (and related variants) definitions to a new models.py class within the provider implementation for those providers that require additional dependencies. It also meant moving a couple of imports from top-level imports to inside `get_adapter_impl` for some providers, which follows the pattern used by multiple existing providers. To ensure we don't regress and accidentally add new imports that cause distro_codegen.py to fail, the stubbed-in pre-commit hook for distro_codegen.py was uncommented and slightly tweaked to run via `uv run python ...` to ensure it runs with only the project's default dependencies and to run automatically instead of manually. Lastly, this updates distro_codegen.py itself to keep track of paths it might have changed and to only `git diff` those specific paths when checking for changed files instead of doing a diff on the entire working tree. The latter was overly broad and would require a user have no other unstaged changes in their working tree, even if those unstaged changes were unrelated to generated code. Now it only flags uncommitted changes for paths distro_codegen.py actually writes to. Our generated code was also out-of-date, presumably because of these issues, so this commit also has some updates to the generated code purely because it was out of sync, and the pre-commit hook now enforces things to be updated. (Closes #1122) ## Test Plan I manually tested distro_codegen.py and the pre-commit hook to verify those work as expected, flagging any uncommited changes and catching any imports that attempt to pull in provider-specific dependencies. However, I do not have valid api keys to the impacted provider implementations, and am unable to easily run the inference tests against each changed provider. There are no functional changes to the provider implementations here, but I'd appreciate a second set of eyes on the changed import statements and moving of MODEL_ALIASES type code to a separate models.py to ensure I didn't make any obvious errors. --------- Signed-off-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
9e03df983e
commit
e9b8259cf9
28 changed files with 334 additions and 248 deletions
|
@ -23,6 +23,22 @@ from llama_stack.distribution.build import (
|
|||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ChangedPathTracker:
|
||||
"""Track a list of paths we may have changed."""
|
||||
|
||||
def __init__(self):
|
||||
self._changed_paths = []
|
||||
|
||||
def add_paths(self, *paths):
|
||||
for path in paths:
|
||||
path = str(path)
|
||||
if path not in self._changed_paths:
|
||||
self._changed_paths.append(path)
|
||||
|
||||
def changed_paths(self):
|
||||
return self._changed_paths
|
||||
|
||||
|
||||
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
|
||||
"""Find immediate subdirectories in the templates folder."""
|
||||
if not templates_dir.exists():
|
||||
|
@ -31,7 +47,7 @@ def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
|
|||
return sorted(d for d in templates_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
|
||||
|
||||
|
||||
def process_template(template_dir: Path, progress) -> None:
|
||||
def process_template(template_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
|
||||
"""Process a single template directory."""
|
||||
progress.print(f"Processing {template_dir.name}")
|
||||
|
||||
|
@ -44,9 +60,12 @@ def process_template(template_dir: Path, progress) -> None:
|
|||
if template_func := getattr(module, "get_distribution_template", None):
|
||||
template = template_func()
|
||||
|
||||
yaml_output_dir = REPO_ROOT / "llama_stack" / "templates" / template.name
|
||||
doc_output_dir = REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro"
|
||||
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
|
||||
template.save_distribution(
|
||||
yaml_output_dir=REPO_ROOT / "llama_stack" / "templates" / template.name,
|
||||
doc_output_dir=REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro",
|
||||
yaml_output_dir=yaml_output_dir,
|
||||
doc_output_dir=doc_output_dir,
|
||||
)
|
||||
else:
|
||||
progress.print(f"[yellow]Warning: {template_dir.name} has no get_distribution_template function")
|
||||
|
@ -56,14 +75,19 @@ def process_template(template_dir: Path, progress) -> None:
|
|||
raise e
|
||||
|
||||
|
||||
def check_for_changes() -> bool:
|
||||
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
|
||||
"""Check if there are any uncommitted changes."""
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--exit-code"],
|
||||
cwd=REPO_ROOT,
|
||||
capture_output=True,
|
||||
)
|
||||
return result.returncode != 0
|
||||
has_changes = False
|
||||
for path in change_tracker.changed_paths():
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--exit-code", path],
|
||||
cwd=REPO_ROOT,
|
||||
capture_output=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f"Change detected in '{path}'.", file=sys.stderr)
|
||||
has_changes = True
|
||||
return has_changes
|
||||
|
||||
|
||||
def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
|
||||
|
@ -83,7 +107,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
|
|||
return None, []
|
||||
|
||||
|
||||
def generate_dependencies_file():
|
||||
def generate_dependencies_file(change_tracker: ChangedPathTracker):
|
||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
||||
distribution_deps = {}
|
||||
|
||||
|
@ -93,12 +117,14 @@ def generate_dependencies_file():
|
|||
distribution_deps[name] = deps
|
||||
|
||||
deps_file = REPO_ROOT / "distributions" / "dependencies.json"
|
||||
change_tracker.add_paths(deps_file)
|
||||
with open(deps_file, "w") as f:
|
||||
f.write(json.dumps(distribution_deps, indent=2) + "\n")
|
||||
|
||||
|
||||
def main():
|
||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
||||
change_tracker = ChangedPathTracker()
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
|
@ -108,7 +134,7 @@ def main():
|
|||
task = progress.add_task("Processing distribution templates...", total=len(template_dirs))
|
||||
|
||||
# Create a partial function with the progress bar
|
||||
process_func = partial(process_template, progress=progress)
|
||||
process_func = partial(process_template, progress=progress, change_tracker=change_tracker)
|
||||
|
||||
# Process templates in parallel
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
|
@ -116,9 +142,9 @@ def main():
|
|||
list(executor.map(process_func, template_dirs))
|
||||
progress.update(task, advance=len(template_dirs))
|
||||
|
||||
generate_dependencies_file()
|
||||
generate_dependencies_file(change_tracker)
|
||||
|
||||
if check_for_changes():
|
||||
if check_for_changes(change_tracker):
|
||||
print(
|
||||
"Distribution template changes detected. Please commit the changes.",
|
||||
file=sys.stderr,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue