forked from phoenix-oss/llama-stack-mirror
Add hacks because Cloudfront config limits on the 405b model files
This commit is contained in:
parent
404af06e02
commit
23014ea4d1
1 changed files with 47 additions and 21 deletions
|
@ -10,7 +10,6 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -24,6 +23,7 @@ from llama_models.llama3_1.api.datatypes import (
|
||||||
from llama_models.llama3_1.api.sku_list import (
|
from llama_models.llama3_1.api.sku_list import (
|
||||||
llama3_1_model_list,
|
llama3_1_model_list,
|
||||||
llama_meta_folder_path,
|
llama_meta_folder_path,
|
||||||
|
llama_meta_pth_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
@ -121,18 +121,21 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
gpus = model.hardware_requirements.gpu_count
|
gpus = model.hardware_requirements.gpu_count
|
||||||
files = [
|
files = [
|
||||||
"tokenizer.model",
|
"tokenizer.model",
|
||||||
|
"params.json",
|
||||||
]
|
]
|
||||||
files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
|
|
||||||
if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
|
if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
|
||||||
files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)])
|
files.extend([f"fp8_scales_{i}.pt" for i in range(gpus)])
|
||||||
|
files.extend([f"consolidated.{i:02d}.pth" for i in range(gpus)])
|
||||||
|
|
||||||
folder_path = llama_meta_folder_path(model)
|
folder_path = llama_meta_folder_path(model)
|
||||||
|
pth_size = llama_meta_pth_size(model)
|
||||||
|
|
||||||
# I believe we can use some concurrency here if needed but not sure it is worth it
|
# I believe we can use some concurrency here if needed but not sure it is worth it
|
||||||
for f in files:
|
for f in files:
|
||||||
output_file = str(output_dir / f)
|
output_file = str(output_dir / f)
|
||||||
url = meta_url.replace("*", f"{folder_path}/{f}")
|
url = meta_url.replace("*", f"{folder_path}/{f}")
|
||||||
downloader = ResumableDownloader(url, output_file)
|
total_size = pth_size if "consolidated" in f else 0
|
||||||
|
downloader = ResumableDownloader(url, output_file, total_size)
|
||||||
asyncio.run(downloader.download())
|
asyncio.run(downloader.download())
|
||||||
|
|
||||||
def _run_download_cmd(self, args: argparse.Namespace):
|
def _run_download_cmd(self, args: argparse.Namespace):
|
||||||
|
@ -151,16 +154,25 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
|
|
||||||
|
|
||||||
class ResumableDownloader:
|
class ResumableDownloader:
|
||||||
def __init__(self, url: str, output_file: str, buffer_size: int = 32 * 1024):
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
output_file: str,
|
||||||
|
total_size: int = 0,
|
||||||
|
buffer_size: int = 32 * 1024,
|
||||||
|
):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.output_file = output_file
|
self.output_file = output_file
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
self.total_size: Optional[int] = None
|
self.total_size = total_size
|
||||||
self.downloaded_size = 0
|
self.downloaded_size = 0
|
||||||
self.start_size = 0
|
self.start_size = 0
|
||||||
self.start_time = 0
|
self.start_time = 0
|
||||||
|
|
||||||
async def get_file_info(self, client: httpx.AsyncClient) -> None:
|
async def get_file_info(self, client: httpx.AsyncClient) -> None:
|
||||||
|
if self.total_size > 0:
|
||||||
|
return
|
||||||
|
|
||||||
response = await client.head(self.url, follow_redirects=True)
|
response = await client.head(self.url, follow_redirects=True)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
self.url = str(response.url) # Update URL in case of redirects
|
self.url = str(response.url) # Update URL in case of redirects
|
||||||
|
@ -192,22 +204,36 @@ class ResumableDownloader:
|
||||||
f"Not enough disk space to download `{self.output_file}`"
|
f"Not enough disk space to download `{self.output_file}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
headers = {"Range": f"bytes={self.downloaded_size}-"}
|
while True:
|
||||||
try:
|
if self.downloaded_size >= self.total_size:
|
||||||
async with client.stream("GET", self.url, headers=headers) as response:
|
break
|
||||||
response.raise_for_status()
|
|
||||||
with open(self.output_file, "ab") as file:
|
# Cloudfront has a max-size limit
|
||||||
async for chunk in response.aiter_bytes(self.buffer_size):
|
max_chunk_size = 27_000_000_000
|
||||||
file.write(chunk)
|
request_size = min(
|
||||||
self.downloaded_size += len(chunk)
|
self.total_size - self.downloaded_size, max_chunk_size
|
||||||
self.print_progress()
|
)
|
||||||
except httpx.HTTPError as e:
|
headers = {
|
||||||
print(f"\nDownload interrupted: {e}")
|
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
|
||||||
print("You can resume the download by running the script again.")
|
}
|
||||||
except Exception as e:
|
# print(f"Downloading `{self.output_file}`....{headers}")
|
||||||
print(f"\nAn error occurred: {e}")
|
try:
|
||||||
else:
|
async with client.stream(
|
||||||
print(f"\nFinished downloading `{self.output_file}`....")
|
"GET", self.url, headers=headers
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
with open(self.output_file, "ab") as file:
|
||||||
|
async for chunk in response.aiter_bytes(self.buffer_size):
|
||||||
|
file.write(chunk)
|
||||||
|
self.downloaded_size += len(chunk)
|
||||||
|
self.print_progress()
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
print(f"\nDownload interrupted: {e}")
|
||||||
|
print("You can resume the download by running the script again.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nAn error occurred: {e}")
|
||||||
|
|
||||||
|
print(f"\nFinished downloading `{self.output_file}`....")
|
||||||
|
|
||||||
def print_progress(self) -> None:
|
def print_progress(self) -> None:
|
||||||
percent = (self.downloaded_size / self.total_size) * 100
|
percent = (self.downloaded_size / self.total_size) * 100
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue