fix: Get distro_codegen.py working with default deps and enabled in pre-commit hooks (#1123)

# What does this PR do?

Before this change, `distro_codegen.py` would only work if the user
manually installed multiple provider-specific dependencies (see #1122).
Now, users can run `distro_codegen.py` without any provider-specific
dependencies because we avoid importing the entire provider
implementations just to get the config needed to build the provider
template.

Concretely, this mostly means moving the
MODEL_ALIASES (and related variants) definitions to a new models.py
class within the provider implementation for those providers that
require additional dependencies. It also meant moving a couple of
imports from top-level imports to inside `get_adapter_impl` for some
providers, which follows the pattern used by multiple existing
providers.

To ensure we don't regress and accidentally add new imports that cause
distro_codegen.py to fail, the stubbed-in pre-commit hook for
distro_codegen.py was uncommented and slightly tweaked to run via `uv
run python ...` to ensure it runs with only the project's default
dependencies and to run automatically instead of manually.

Lastly, this updates distro_codegen.py itself to keep track of paths it
might have changed and to only `git diff` those specific paths when
checking for changed files instead of doing a diff on the entire working
tree. The latter was overly broad and would require a user have no other
unstaged changes in their working tree, even if those unstaged changes
were unrelated to generated code. Now it only flags uncommitted changes
for paths distro_codegen.py actually writes to.

Our generated code was also out-of-date, presumably because of these
issues, so this commit also has some updates to the generated code
purely because it was out of sync, and the pre-commit hook now enforces
things to be updated.

(Closes #1122)

## Test Plan

I manually tested distro_codegen.py and the pre-commit hook to verify
those work as expected, flagging any uncommited changes and catching any
imports that attempt to pull in provider-specific dependencies.

However, I do not have valid api keys to the impacted provider
implementations, and am unable to easily run the inference tests against
each changed provider. There are no functional changes to the provider
implementations here, but I'd appreciate a second set of eyes on the
changed import statements and moving of MODEL_ALIASES type code to a
separate models.py to ensure I didn't make any obvious errors.

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Ben Browning 2025-02-19 21:39:20 -05:00 committed by GitHub
parent 9e03df983e
commit e9b8259cf9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 334 additions and 248 deletions

View file

@ -23,6 +23,22 @@ from llama_stack.distribution.build import (
REPO_ROOT = Path(__file__).parent.parent.parent
class ChangedPathTracker:
"""Track a list of paths we may have changed."""
def __init__(self):
self._changed_paths = []
def add_paths(self, *paths):
for path in paths:
path = str(path)
if path not in self._changed_paths:
self._changed_paths.append(path)
def changed_paths(self):
return self._changed_paths
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
@ -31,7 +47,7 @@ def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
return sorted(d for d in templates_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
def process_template(template_dir: Path, progress) -> None:
def process_template(template_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
"""Process a single template directory."""
progress.print(f"Processing {template_dir.name}")
@ -44,9 +60,12 @@ def process_template(template_dir: Path, progress) -> None:
if template_func := getattr(module, "get_distribution_template", None):
template = template_func()
yaml_output_dir = REPO_ROOT / "llama_stack" / "templates" / template.name
doc_output_dir = REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro"
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
template.save_distribution(
yaml_output_dir=REPO_ROOT / "llama_stack" / "templates" / template.name,
doc_output_dir=REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro",
yaml_output_dir=yaml_output_dir,
doc_output_dir=doc_output_dir,
)
else:
progress.print(f"[yellow]Warning: {template_dir.name} has no get_distribution_template function")
@ -56,14 +75,19 @@ def process_template(template_dir: Path, progress) -> None:
raise e
def check_for_changes() -> bool:
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
"""Check if there are any uncommitted changes."""
result = subprocess.run(
["git", "diff", "--exit-code"],
cwd=REPO_ROOT,
capture_output=True,
)
return result.returncode != 0
has_changes = False
for path in change_tracker.changed_paths():
result = subprocess.run(
["git", "diff", "--exit-code", path],
cwd=REPO_ROOT,
capture_output=True,
)
if result.returncode != 0:
print(f"Change detected in '{path}'.", file=sys.stderr)
has_changes = True
return has_changes
def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
@ -83,7 +107,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
return None, []
def generate_dependencies_file():
def generate_dependencies_file(change_tracker: ChangedPathTracker):
templates_dir = REPO_ROOT / "llama_stack" / "templates"
distribution_deps = {}
@ -93,12 +117,14 @@ def generate_dependencies_file():
distribution_deps[name] = deps
deps_file = REPO_ROOT / "distributions" / "dependencies.json"
change_tracker.add_paths(deps_file)
with open(deps_file, "w") as f:
f.write(json.dumps(distribution_deps, indent=2) + "\n")
def main():
templates_dir = REPO_ROOT / "llama_stack" / "templates"
change_tracker = ChangedPathTracker()
with Progress(
SpinnerColumn(),
@ -108,7 +134,7 @@ def main():
task = progress.add_task("Processing distribution templates...", total=len(template_dirs))
# Create a partial function with the progress bar
process_func = partial(process_template, progress=progress)
process_func = partial(process_template, progress=progress, change_tracker=change_tracker)
# Process templates in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
@ -116,9 +142,9 @@ def main():
list(executor.map(process_func, template_dirs))
progress.update(task, advance=len(template_dirs))
generate_dependencies_file()
generate_dependencies_file(change_tracker)
if check_for_changes():
if check_for_changes(change_tracker):
print(
"Distribution template changes detected. Please commit the changes.",
file=sys.stderr,