Add hacks because of Cloudfront config limits on the 405b model files

Ashwin Bharambe 2024-07-30 13:46:20 -07:00
parent 404af06e02
commit 23014ea4d1

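In short: CloudFront refuses to serve a single response larger than its configured cap, and the consolidated .pth shards of the 405b model blow past it. The downloader therefore fetches each shard as a sequence of bounded Range requests, and since the shard size is already known from the model metadata (llama_meta_pth_size), it is passed into ResumableDownloader up front so the HEAD preflight can be skipped. A minimal standalone sketch of the pattern, assuming httpx; the function name is illustrative, and 27 GB is the cap the commit hardcodes:

    import asyncio

    import httpx

    async def download_in_ranges(url: str, output_file: str, total_size: int) -> None:
        # CloudFront caps single-response size, so fetch the object as a
        # series of bounded ranges instead of one open-ended GET.
        max_chunk_size = 27_000_000_000
        downloaded = 0
        async with httpx.AsyncClient(follow_redirects=True) as client:
            with open(output_file, "wb") as f:
                while downloaded < total_size:
                    request_size = min(total_size - downloaded, max_chunk_size)
                    # Range ends are inclusive; the -1 asks for exactly
                    # request_size bytes (the diff omits it and harmlessly
                    # over-asks by one byte per round).
                    end = downloaded + request_size - 1
                    headers = {"Range": f"bytes={downloaded}-{end}"}
                    async with client.stream("GET", url, headers=headers) as response:
                        response.raise_for_status()
                        async for chunk in response.aiter_bytes(32 * 1024):
                            f.write(chunk)
                            downloaded += len(chunk)

    # Hypothetical usage: asyncio.run(download_in_ranges(url, "consolidated.00.pth", pth_size))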

@@ -10,7 +10,6 @@ import os
 import shutil
 import time
 from pathlib import Path
-from typing import Optional

 import httpx
@@ -24,6 +23,7 @@ from llama_models.llama3_1.api.datatypes import (
 from llama_models.llama3_1.api.sku_list import (
     llama3_1_model_list,
     llama_meta_folder_path,
+    llama_meta_pth_size,
 )

 from llama_toolchain.cli.subcommand import Subcommand
@@ -121,18 +121,21 @@ safetensors files to avoid downloading duplicate weights.
         gpus = model.hardware_requirements.gpu_count
         files = [
             "tokenizer.model",
+            "params.json",
         ]
-        files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
         if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
             files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)])
+        files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])

         folder_path = llama_meta_folder_path(model)
+        pth_size = llama_meta_pth_size(model)

         # I believe we can use some concurrency here if needed but not sure it is worth it
         for f in files:
             output_file = str(output_dir / f)
             url = meta_url.replace("*", f"{folder_path}/{f}")
-            downloader = ResumableDownloader(url, output_file)
+            total_size = pth_size if "consolidated" in f else 0
+            downloader = ResumableDownloader(url, output_file, total_size)
             asyncio.run(downloader.download())

     def _run_download_cmd(self, args: argparse.Namespace):
@@ -151,16 +154,25 @@ safetensors files to avoid downloading duplicate weights.

 class ResumableDownloader:
-    def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024):
+    def __init__(
+        self,
+        url: str,
+        output_file: str,
+        total_size: int = 0,
+        buffer_size: int = 32 * 1024,
+    ):
         self.url = url
         self.output_file = output_file
         self.buffer_size = buffer_size
-        self.total_size: Optional[int] = None
+        self.total_size = total_size
         self.downloaded_size = 0
         self.start_size = 0
         self.start_time = 0

     async def get_file_info(self, client: httpx.AsyncClient) -> None:
+        if self.total_size > 0:
+            return
+
         response = await client.head(self.url, follow_redirects=True)
         response.raise_for_status()
         self.url = str(response.url)  # Update URL in case of redirects
@@ -192,9 +204,23 @@ class ResumableDownloader:
                     f"Not enough disk space to download `{self.output_file}`"
                 )

-            headers = {"Range": f"bytes={self.downloaded_size}-"}
+            while True:
+                if self.downloaded_size >= self.total_size:
+                    break
+
+                # Cloudfront has a max-size limit
+                max_chunk_size = 27_000_000_000
+                request_size = min(
+                    self.total_size - self.downloaded_size, max_chunk_size
+                )
+                headers = {
+                    "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
+                }
+                # print(f"Downloading `{self.output_file}`....{headers}")
                 try:
-                async with client.stream("GET", self.url, headers=headers) as response:
+                    async with client.stream(
+                        "GET", self.url, headers=headers
+                    ) as response:
                         response.raise_for_status()
                         with open(self.output_file, "ab") as file:
                             async for chunk in response.aiter_bytes(self.buffer_size):
@@ -206,7 +232,7 @@ class ResumableDownloader:
                     print("You can resume the download by running the script again.")
                 except Exception as e:
                     print(f"\nAn error occurred: {e}")
-
+                else:
                     print(f"\nFinished downloading `{self.output_file}`....")

     def print_progress(self) -> None:
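For intuition, these are the ranges the new loop issues for a hypothetical 60 GB shard (made-up size; the formula and its open upper bound match the hunk above, though in the real code downloaded_size advances by the bytes actually received rather than by request_size):

    total_size = 60_000_000_000
    max_chunk_size = 27_000_000_000  # the Cloudfront cap from the diff

    downloaded = 0
    while downloaded < total_size:
        request_size = min(total_size - downloaded, max_chunk_size)
        print(f"Range: bytes={downloaded}-{downloaded + request_size}")
        downloaded += request_size

    # Range: bytes=0-27000000000
    # Range: bytes=27000000000-54000000000
    # Range: bytes=54000000000-60000000000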