Add hacks because Cloudfront config limits on the 405b model files

This commit is contained in:
Ashwin Bharambe 2024-07-30 13:46:20 -07:00
parent 404af06e02
commit 23014ea4d1

View file

@ -10,7 +10,6 @@ import os
import shutil
import time
from pathlib import Path
from typing import Optional
import httpx
@ -24,6 +23,7 @@ from llama_models.llama3_1.api.datatypes import (
from llama_models.llama3_1.api.sku_list import (
llama3_1_model_list,
llama_meta_folder_path,
llama_meta_pth_size,
)
from llama_toolchain.cli.subcommand import Subcommand
@ -121,18 +121,21 @@ safetensors files to avoid downloading duplicate weights.
gpus = model.hardware_requirements.gpu_count
files = [
"tokenizer.model",
"params.json",
]
files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)])
files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
folder_path = llama_meta_folder_path(model)
pth_size = llama_meta_pth_size(model)
# I believe we can use some concurrency here if needed but not sure it is worth it
for f in files:
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{folder_path}/{f}")
downloader = ResumableDownloader(url, output_file)
total_size = pth_size if "consolidated" in f else 0
downloader = ResumableDownloader(url, output_file, total_size)
asyncio.run(downloader.download())
def _run_download_cmd(self, args: argparse.Namespace):
@ -151,16 +154,25 @@ safetensors files to avoid downloading duplicate weights.
class ResumableDownloader:
def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024):
def __init__(
self,
url: str,
output_file: str,
total_size: int = 0,
buffer_size: int = 32 * 1024,
):
self.url = url
self.output_file = output_file
self.buffer_size = buffer_size
self.total_size: Optional[int] = None
self.total_size = total_size
self.downloaded_size = 0
self.start_size = 0
self.start_time = 0
async def get_file_info(self, client: httpx.AsyncClient) -> None:
if self.total_size > 0:
return
response = await client.head(self.url, follow_redirects=True)
response.raise_for_status()
self.url = str(response.url) # Update URL in case of redirects
@ -192,22 +204,36 @@ class ResumableDownloader:
f"Not enough disk space to download `{self.output_file}`"
)
headers = {"Range": f"bytes={self.downloaded_size}-"}
try:
async with client.stream("GET", self.url, headers=headers) as response:
response.raise_for_status()
with open(self.output_file, "ab") as file:
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
self.downloaded_size += len(chunk)
self.print_progress()
except httpx.HTTPError as e:
print(f"\nDownload interrupted: {e}")
print("You can resume the download by running the script again.")
except Exception as e:
print(f"\nAn error occurred: {e}")
else:
print(f"\nFinished downloading `{self.output_file}`....")
while True:
if self.downloaded_size >= self.total_size:
break
# Cloudfront has a max-size limit
max_chunk_size = 27_000_000_000
request_size = min(
self.total_size - self.downloaded_size, max_chunk_size
)
headers = {
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
}
# print(f"Downloading `{self.output_file}`....{headers}")
try:
async with client.stream(
"GET", self.url, headers=headers
) as response:
response.raise_for_status()
with open(self.output_file, "ab") as file:
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
self.downloaded_size += len(chunk)
self.print_progress()
except httpx.HTTPError as e:
print(f"\nDownload interrupted: {e}")
print("You can resume the download by running the script again.")
except Exception as e:
print(f"\nAn error occurred: {e}")
print(f"\nFinished downloading `{self.output_file}`....")
def print_progress(self) -> None:
percent = (self.downloaded_size / self.total_size) * 100