chore(api): add mypy coverage to cli

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 16:58:13 +02:00
parent cd0ad21111
commit 66881b9cca
2 changed files with 34 additions and 21 deletions

View file

@ -22,6 +22,7 @@ from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TaskID,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
@ -102,7 +103,7 @@ class DownloadTask:
output_file: str
total_size: int = 0
downloaded_size: int = 0
task_id: int | None = None
task_id: TaskID | None = None
retries: int = 0
max_retries: int = 3
@ -139,13 +140,13 @@ class ParallelDownloader:
console=self.console,
expand=True,
)
self.client_options = {
self.client_options: dict[str, object] = {
"timeout": httpx.Timeout(timeout),
"follow_redirects": True,
}
async def retry_with_exponential_backoff(self, task: DownloadTask, func, *args, **kwargs):
last_exception = None
last_exception: Exception | None = None
for attempt in range(task.max_retries):
try:
return await func(*args, **kwargs)
@ -159,15 +160,18 @@ class ParallelDownloader:
)
await asyncio.sleep(wait_time)
continue
raise last_exception
if last_exception is not None:
raise last_exception
raise RuntimeError("Retry failed without capturing exception")
async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
if task.total_size > 0:
self.progress.update(task.task_id, total=task.total_size)
if task.task_id is not None:
self.progress.update(task.task_id, total=task.total_size)
return
async def _get_info():
response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
response = await client.head(task.url, headers={"Accept-Encoding": "identity"})
response.raise_for_status()
return response
@ -199,7 +203,7 @@ class ParallelDownloader:
async def download_chunk(self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int) -> None:
async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"}
async with client.stream("GET", task.url, headers=headers, **self.client_options) as response:
async with client.stream("GET", task.url, headers=headers) as response:
response.raise_for_status()
with open(task.output_file, "ab") as file:
@ -207,10 +211,11 @@ class ParallelDownloader:
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
task.downloaded_size += len(chunk)
self.progress.update(
task.task_id,
completed=task.downloaded_size,
)
if task.task_id is not None:
self.progress.update(
task.task_id,
completed=task.downloaded_size,
)
try:
await self.retry_with_exponential_backoff(task, _download_chunk)
@ -228,14 +233,21 @@ class ParallelDownloader:
async def download_file(self, task: DownloadTask) -> None:
try:
async with httpx.AsyncClient(**self.client_options) as client:
client_timeout = self.client_options["timeout"]
if not isinstance(client_timeout, httpx.Timeout):
raise TypeError(f"Expected httpx.Timeout, got {type(client_timeout)}")
async with httpx.AsyncClient(
timeout=client_timeout,
follow_redirects=bool(self.client_options["follow_redirects"]),
) as client:
await self.get_file_info(client, task)
# Check if file is already downloaded
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(f"[green]Already downloaded {task.output_file}[/green]")
self.progress.update(task.task_id, completed=task.total_size)
if task.task_id is not None:
self.progress.update(task.task_id, completed=task.total_size)
return
await self.prepare_download(task)
@ -259,7 +271,8 @@ class ParallelDownloader:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
if task.task_id is not None:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
@ -349,7 +362,7 @@ def _hf_download(
except RepositoryNotFoundError:
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
except Exception as e:
parser.error(e)
parser.error(str(e))
print(f"\nSuccessfully downloaded model to {true_output_dir}")
@ -465,13 +478,13 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
prompt_guard_model_sku_map,
)
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
prompt_guard_download_info_map = prompt_guard_download_info_map()
prompt_guard_model_sku_map_dict = prompt_guard_model_sku_map()
prompt_guard_download_info_map_dict = prompt_guard_download_info_map()
for model_id in model_ids:
if model_id in prompt_guard_model_sku_map.keys():
model = prompt_guard_model_sku_map[model_id]
info = prompt_guard_download_info_map[model_id]
if model_id in prompt_guard_model_sku_map_dict.keys():
model = prompt_guard_model_sku_map_dict[model_id]
info = prompt_guard_download_info_map_dict[model_id]
else:
model = resolve_model(model_id)
if model is None:

View file

@ -225,7 +225,7 @@ follow_imports = "silent"
# to exclude the entire directory.
exclude = [
# As we fix more and more of these, we should remove them from the list
"^llama_stack/cli/download\\.py$",
"^llama_stack/apis/common/training_types\\.py$",
"^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/distribution/build\\.py$",
"^llama_stack/distribution/client\\.py$",