From 23014ea4d13a1e42a4676c93d35eea915d6693fb Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 30 Jul 2024 13:46:20 -0700 Subject: [PATCH] Add hacks because Cloudfront config limits on the 405b model files --- llama_toolchain/cli/download.py | 68 +++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index ee774c67d..63452a311 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -10,7 +10,6 @@ import os import shutil import time from pathlib import Path -from typing import Optional import httpx @@ -24,6 +23,7 @@ from llama_models.llama3_1.api.datatypes import ( from llama_models.llama3_1.api.sku_list import ( llama3_1_model_list, llama_meta_folder_path, + llama_meta_pth_size, ) from llama_toolchain.cli.subcommand import Subcommand @@ -121,18 +121,21 @@ safetensors files to avoid downloading duplicate weights. gpus = model.hardware_requirements.gpu_count files = [ "tokenizer.model", + "params.json", ] - files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)]) if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)]) + files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)]) folder_path = llama_meta_folder_path(model) + pth_size = llama_meta_pth_size(model) # I believe we can use some concurrency here if needed but not sure it is worth it for f in files: output_file = str(output_dir / f) url = meta_url.replace("*", f"{folder_path}/{f}") - downloader = ResumableDownloader(url, output_file) + total_size = pth_size if "consolidated" in f else 0 + downloader = ResumableDownloader(url, output_file, total_size) asyncio.run(downloader.download()) def _run_download_cmd(self, args: argparse.Namespace): @@ -151,16 +154,25 @@ safetensors files to avoid downloading duplicate weights. class ResumableDownloader: - def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024): + def __init__( + self, + url: str, + output_file: str, + total_size: int = 0, + buffer_size: int = 32 * 1024, + ): self.url = url self.output_file = output_file self.buffer_size = buffer_size - self.total_size: Optional[int] = None + self.total_size = total_size self.downloaded_size = 0 self.start_size = 0 self.start_time = 0 async def get_file_info(self, client: httpx.AsyncClient) -> None: + if self.total_size > 0: + return + response = await client.head(self.url, follow_redirects=True) response.raise_for_status() self.url = str(response.url) # Update URL in case of redirects @@ -192,22 +204,36 @@ class ResumableDownloader: f"Not enough disk space to download `{self.output_file}`" ) - headers = {"Range": f"bytes={self.downloaded_size}-"} - try: - async with client.stream("GET", self.url, headers=headers) as response: - response.raise_for_status() - with open(self.output_file, "ab") as file: - async for chunk in response.aiter_bytes(self.buffer_size): - file.write(chunk) - self.downloaded_size += len(chunk) - self.print_progress() - except httpx.HTTPError as e: - print(f"\nDownload interrupted: {e}") - print("You can resume the download by running the script again.") - except Exception as e: - print(f"\nAn error occurred: {e}") - else: - print(f"\nFinished downloading `{self.output_file}`....") + while True: + if self.downloaded_size >= self.total_size: + break + + # Cloudfront has a max-size limit + max_chunk_size = 27_000_000_000 + request_size = min( + self.total_size - self.downloaded_size, max_chunk_size + ) + headers = { + "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}" + } + # print(f"Downloading `{self.output_file}`....{headers}") + try: + async with client.stream( + "GET", self.url, headers=headers + ) as response: + response.raise_for_status() + with open(self.output_file, "ab") as file: + async for chunk in response.aiter_bytes(self.buffer_size): + file.write(chunk) + self.downloaded_size += len(chunk) + self.print_progress() + except httpx.HTTPError as e: + print(f"\nDownload interrupted: {e}") + print("You can resume the download by running the script again.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + print(f"\nFinished downloading `{self.output_file}`....") def print_progress(self) -> None: percent = (self.downloaded_size / self.total_size) * 100