Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
Add hacks because of Cloudfront config limits on the 405b model files
parent 404af06e02
commit 23014ea4d1

1 changed file with 47 additions and 21 deletions
@@ -10,7 +10,6 @@ import os
 import shutil
 import time
 from pathlib import Path
-from typing import Optional
 
 import httpx
 
@@ -24,6 +23,7 @@ from llama_models.llama3_1.api.datatypes import (
 from llama_models.llama3_1.api.sku_list import (
     llama3_1_model_list,
     llama_meta_folder_path,
+    llama_meta_pth_size,
 )
 
 from llama_toolchain.cli.subcommand import Subcommand
@@ -121,18 +121,21 @@ safetensors files to avoid downloading duplicate weights.
         gpus = model.hardware_requirements.gpu_count
         files = [
             "tokenizer.model",
             "params.json",
         ]
-        files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
         if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
             files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)])
+        files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
 
         folder_path = llama_meta_folder_path(model)
+        pth_size = llama_meta_pth_size(model)
 
         # I believe we can use some concurrency here if needed but not sure it is worth it
         for f in files:
             output_file = str(output_dir / f)
             url = meta_url.replace("*", f"{folder_path}/{f}")
-            downloader = ResumableDownloader(url, output_file)
+            total_size = pth_size if "consolidated" in f else 0
+            downloader = ResumableDownloader(url, output_file, total_size)
             asyncio.run(downloader.download())
 
     def _run_download_cmd(self, args: argparse.Namespace):
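
The loop above hard-codes which files have a size known ahead of time: only the consolidated .pth shards, whose byte size llama_meta_pth_size reports from model metadata, get a nonzero total_size. A minimal sketch of that selection pattern, with a made-up constant standing in for the metadata lookup (KNOWN_PTH_SIZE is illustrative, not part of the repo):

```python
# Illustrative only: KNOWN_PTH_SIZE stands in for llama_meta_pth_size(model).
KNOWN_PTH_SIZE = 17_000_000_000  # hypothetical size of one consolidated shard


def expected_size(filename: str) -> int:
    # Big consolidated shards have a pre-known size, so the downloader can
    # skip its HEAD probe; 0 means "unknown, discover the size over HTTP".
    return KNOWN_PTH_SIZE if "consolidated" in filename else 0


assert expected_size("consolidated.00.pth") == KNOWN_PTH_SIZE
assert expected_size("params.json") == 0
```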
@@ -151,16 +154,25 @@ safetensors files to avoid downloading duplicate weights.
 
 
 class ResumableDownloader:
-    def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024):
+    def __init__(
+        self,
+        url: str,
+        output_file: str,
+        total_size: int = 0,
+        buffer_size: int = 32 * 1024,
+    ):
         self.url = url
         self.output_file = output_file
         self.buffer_size = buffer_size
-        self.total_size: Optional[int] = None
+        self.total_size = total_size
         self.downloaded_size = 0
         self.start_size = 0
         self.start_time = 0
 
     async def get_file_info(self, client: httpx.AsyncClient) -> None:
+        if self.total_size > 0:
+            return
+
         response = await client.head(self.url, follow_redirects=True)
         response.raise_for_status()
         self.url = str(response.url)  # Update URL in case of redirects
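
With this hunk, get_file_info returns early whenever the caller already supplied a size, so the HEAD round-trip only happens for files of unknown length. For reference, a self-contained sketch of that discovery step using documented httpx calls; the URL is a placeholder, and falling back to 0 when Content-Length is absent is an assumption of this sketch:

```python
import asyncio

import httpx


async def probe(url: str) -> tuple[str, int]:
    async with httpx.AsyncClient() as client:
        # follow_redirects=True lands on the final (possibly signed) URL so
        # that later ranged GETs can be issued against it directly.
        response = await client.head(url, follow_redirects=True)
        response.raise_for_status()
        size = int(response.headers.get("content-length", 0))
        return str(response.url), size


# Example with a placeholder URL:
# final_url, size = asyncio.run(probe("https://example.com/model.bin"))
```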
@@ -192,22 +204,36 @@ class ResumableDownloader:
                 f"Not enough disk space to download `{self.output_file}`"
             )
 
-        headers = {"Range": f"bytes={self.downloaded_size}-"}
-        try:
-            async with client.stream("GET", self.url, headers=headers) as response:
-                response.raise_for_status()
-                with open(self.output_file, "ab") as file:
-                    async for chunk in response.aiter_bytes(self.buffer_size):
-                        file.write(chunk)
-                        self.downloaded_size += len(chunk)
-                        self.print_progress()
-        except httpx.HTTPError as e:
-            print(f"\nDownload interrupted: {e}")
-            print("You can resume the download by running the script again.")
-        except Exception as e:
-            print(f"\nAn error occurred: {e}")
-        else:
-            print(f"\nFinished downloading `{self.output_file}`....")
+        while True:
+            if self.downloaded_size >= self.total_size:
+                break
+
+            # Cloudfront has a max-size limit
+            max_chunk_size = 27_000_000_000
+            request_size = min(
+                self.total_size - self.downloaded_size, max_chunk_size
+            )
+            headers = {
+                "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
+            }
+            # print(f"Downloading `{self.output_file}`....{headers}")
+            try:
+                async with client.stream(
+                    "GET", self.url, headers=headers
+                ) as response:
+                    response.raise_for_status()
+                    with open(self.output_file, "ab") as file:
+                        async for chunk in response.aiter_bytes(self.buffer_size):
+                            file.write(chunk)
+                            self.downloaded_size += len(chunk)
+                            self.print_progress()
+            except httpx.HTTPError as e:
+                print(f"\nDownload interrupted: {e}")
+                print("You can resume the download by running the script again.")
+            except Exception as e:
+                print(f"\nAn error occurred: {e}")
+
+        print(f"\nFinished downloading `{self.output_file}`....")
 
     def print_progress(self) -> None:
         percent = (self.downloaded_size / self.total_size) * 100
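
This last hunk is the actual workaround: instead of one open-ended Range request, the loop pulls the file in slices of at most 27 GB, because the CloudFront distribution serving the 405b weights caps the size of a single response, and appending each slice keeps the transfer resumable. A condensed, self-contained sketch of the same chunked-Range pattern, assuming the server honors Range requests; the URL and path arguments are placeholders, and only the 27 GB cap is taken from the diff:

```python
import os

import httpx

MAX_CHUNK = 27_000_000_000  # per-response cap, taken from the diff above


async def download_in_chunks(
    client: httpx.AsyncClient, url: str, path: str, total_size: int
) -> None:
    # Resume from whatever is already on disk.
    downloaded = os.path.getsize(path) if os.path.exists(path) else 0
    while downloaded < total_size:
        # HTTP byte ranges are inclusive, hence the -1 on the end offset.
        end = downloaded + min(total_size - downloaded, MAX_CHUNK) - 1
        headers = {"Range": f"bytes={downloaded}-{end}"}
        async with client.stream("GET", url, headers=headers) as response:
            response.raise_for_status()
            with open(path, "ab") as f:
                async for chunk in response.aiter_bytes(32 * 1024):
                    f.write(chunk)
                    downloaded += len(chunk)
```

Note that the diff's own Range header omits the -1, so each slice asks for one extra byte; that is harmless here because downloaded_size advances by the bytes actually received and the next iteration re-anchors the range accordingly.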