From 66881b9cca0ed2bb980be420158c43d9c7c11fd7 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 8 Jul 2025 16:58:13 +0200 Subject: [PATCH] chore(api): add mypy coverage to cli Signed-off-by: Mustafa Elbehery --- llama_stack/cli/download.py | 53 +++++++++++++++++++++++-------------- pyproject.toml | 2 +- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 30b6e11e9..f1825bd8d 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -22,6 +22,7 @@ from rich.progress import ( BarColumn, DownloadColumn, Progress, + TaskID, TextColumn, TimeRemainingColumn, TransferSpeedColumn, @@ -102,7 +103,7 @@ class DownloadTask: output_file: str total_size: int = 0 downloaded_size: int = 0 - task_id: int | None = None + task_id: TaskID | None = None retries: int = 0 max_retries: int = 3 @@ -139,13 +140,13 @@ class ParallelDownloader: console=self.console, expand=True, ) - self.client_options = { + self.client_options: dict[str, object] = { "timeout": httpx.Timeout(timeout), "follow_redirects": True, } async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs): - last_exception = None + last_exception: Exception | None = None for attempt in range(task.max_retries): try: return await func(*args, **kwargs) @@ -159,15 +160,18 @@ class ParallelDownloader: ) await asyncio.sleep(wait_time) continue - raise last_exception + if last_exception is not None: + raise last_exception + raise RuntimeError("Retry failed without capturing exception") async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None: if task.total_size > 0: - self.progress.update(task.task_id, total=task.total_size) + if task.task_id is not None: + self.progress.update(task.task_id, total=task.total_size) return async def _get_info(): - response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options) + response = await client.head(task.url, headers={"Accept-Encoding": "identity"}) response.raise_for_status() return response @@ -199,7 +203,7 @@ class ParallelDownloader: async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None: async def _download_chunk(): headers = {"Range": f"bytes={start}-{end}"} - async with client.stream("GET", task.url, headers=headers, **self.client_options) as response: + async with client.stream("GET", task.url, headers=headers) as response: response.raise_for_status() with open(task.output_file, "ab") as file: @@ -207,10 +211,11 @@ class ParallelDownloader: async for chunk in response.aiter_bytes(self.buffer_size): file.write(chunk) task.downloaded_size += len(chunk) - self.progress.update( - task.task_id, - completed=task.downloaded_size, - ) + if task.task_id is not None: + self.progress.update( + task.task_id, + completed=task.downloaded_size, + ) try: await self.retry_with_exponential_backoff(task, _download_chunk) @@ -228,14 +233,21 @@ class ParallelDownloader: async def download_file(self, task: DownloadTask) -> None: try: - async with httpx.AsyncClient(**self.client_options) as client: + client_timeout = self.client_options["timeout"] + if not isinstance(client_timeout, httpx.Timeout): + raise TypeError(f"Expected httpx.Timeout, got {type(client_timeout)}") + async with httpx.AsyncClient( + timeout=client_timeout, + follow_redirects=bool(self.client_options["follow_redirects"]), + ) as client: await self.get_file_info(client, task) # Check if file is already downloaded if os.path.exists(task.output_file): if self.verify_file_integrity(task): self.console.print(f"[green]Already downloaded {task.output_file}[/green]") - self.progress.update(task.task_id, completed=task.total_size) + if task.task_id is not None: + self.progress.update(task.task_id, completed=task.total_size) return await self.prepare_download(task) @@ -259,7 +271,8 @@ class ParallelDownloader: raise DownloadError(f"Download failed: {str(e)}") from e except Exception as e: - self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]") + if task.task_id is not None: + self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]") raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e def has_disk_space(self, tasks: list[DownloadTask]) -> bool: @@ -349,7 +362,7 @@ def _hf_download( except RepositoryNotFoundError: parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.") except Exception as e: - parser.error(e) + parser.error(str(e)) print(f"\nSuccessfully downloaded model to {true_output_dir}") @@ -465,13 +478,13 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): prompt_guard_model_sku_map, ) - prompt_guard_model_sku_map = prompt_guard_model_sku_map() - prompt_guard_download_info_map = prompt_guard_download_info_map() + prompt_guard_model_sku_map_dict = prompt_guard_model_sku_map() + prompt_guard_download_info_map_dict = prompt_guard_download_info_map() for model_id in model_ids: - if model_id in prompt_guard_model_sku_map.keys(): - model = prompt_guard_model_sku_map[model_id] - info = prompt_guard_download_info_map[model_id] + if model_id in prompt_guard_model_sku_map_dict.keys(): + model = prompt_guard_model_sku_map_dict[model_id] + info = prompt_guard_download_info_map_dict[model_id] else: model = resolve_model(model_id) if model is None: diff --git a/pyproject.toml b/pyproject.toml index b41e03615..cbd5351ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,7 +225,7 @@ follow_imports = "silent" # to exclude the entire directory. exclude = [ # As we fix more and more of these, we should remove them from the list - "^llama_stack/cli/download\\.py$", + "^llama_stack/apis/common/training_types\\.py$", "^llama_stack/cli/stack/_build\\.py$", "^llama_stack/distribution/build\\.py$", "^llama_stack/distribution/client\\.py$",