From 83c6b200674b94d3e32a033398a79ba06380805e Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Thu, 10 Jul 2025 16:53:38 +0200 Subject: [PATCH] chore(api): add `mypy` coverage to `cli/stack` (#2650) # What does this PR do? This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 ## Test Plan Signed-off-by: Mustafa Elbehery --- llama_stack/cli/stack/_build.py | 24 +++++++++++++++++++----- pyproject.toml | 1 - 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 5d88b1d82..b573b2edc 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -93,7 +93,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) elif args.providers: - providers = dict() + providers_list: dict[str, str | list[str]] = dict() for api_provider in args.providers.split(","): if "=" not in api_provider: cprint( @@ -112,7 +112,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) if provider in providers_for_api: - providers.setdefault(api, []).append(provider) + if api not in providers_list: + providers_list[api] = [] + # Use type guarding to ensure we have a list + provider_value = providers_list[api] + if isinstance(provider_value, list): + provider_value.append(provider) + else: + # Convert string to list and append + providers_list[api] = [provider_value, provider] else: cprint( f"{provider} is not a valid provider for the {api} API.", @@ -121,7 +129,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) distribution_spec = DistributionSpec( - providers=providers, + providers=providers_list, description=",".join(args.providers), ) if not args.image_type: @@ -182,7 +190,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint("Tip: use to see options for the providers.\n", color="green", file=sys.stderr) - providers = dict() + providers: dict[str, str | list[str]] = dict() for api, providers_for_api in get_provider_registry().items(): available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")] if not available_providers: @@ -371,10 +379,16 @@ def _run_stack_build_command_from_build_config( if not image_name: raise ValueError("Please specify an image name when building a venv image") + # At this point, image_name should be guaranteed to be a string + if image_name is None: + raise ValueError("image_name should not be None after validation") + if template_name: build_dir = DISTRIBS_BASE_DIR / template_name build_file_path = build_dir / f"{template_name}-build.yaml" else: + if image_name is None: + raise ValueError("image_name cannot be None") build_dir = DISTRIBS_BASE_DIR / image_name build_file_path = build_dir / f"{image_name}-build.yaml" @@ -395,7 +409,7 @@ def _run_stack_build_command_from_build_config( build_file_path, image_name, template_or_config=template_name or config_path or str(build_file_path), - run_config=run_config_file, + run_config=run_config_file.as_posix() if run_config_file else None, ) if return_code != 0: raise RuntimeError(f"Failed to build image {image_name}") diff --git a/pyproject.toml b/pyproject.toml index 30598e5e3..d84a823a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -226,7 +226,6 @@ follow_imports = "silent" exclude = [ # As we fix more and more of these, we should remove them from the list "^llama_stack/cli/download\\.py$", - "^llama_stack/cli/stack/_build\\.py$", "^llama_stack/distribution/build\\.py$", "^llama_stack/distribution/client\\.py$", "^llama_stack/distribution/request_headers\\.py$",