llama-stack/llama_stack/cli/download.py
2024-09-25 10:29:58 -07:00

339 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
import shutil
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Dict, List
import httpx
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
class Download(Subcommand):
"""Llama cli for downloading llama toolchain assets"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"download",
prog="llama download",
description="Download a model from llama.meta.com or Hugging Face Hub",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_download_parser(self.parser)
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
from llama_models.sku_list import all_registered_models
models = all_registered_models()
parser.add_argument(
"--source",
choices=["meta", "huggingface"],
default="meta",
)
parser.add_argument(
"--model-id",
required=False,
help="See `llama model list` or `llama model list --show-all` for the list of available models",
)
parser.add_argument(
"--hf-token",
type=str,
required=False,
default=None,
help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
)
parser.add_argument(
"--meta-url",
type=str,
required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
)
parser.add_argument(
"--ignore-patterns",
type=str,
required=False,
default="*.safetensors",
help="""
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights.
""",
)
parser.add_argument(
"--manifest-file",
type=str,
help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
required=False,
)
parser.set_defaults(func=partial(run_download_cmd, parser=parser))
def _hf_download(
model: "Model",
hf_token: str,
ignore_patterns: str,
parser: argparse.ArgumentParser,
):
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from llama_stack.distribution.utils.model_utils import model_local_dir
repo_id = model.huggingface_repo
if repo_id is None:
raise ValueError(f"No repo id found for model {model.descriptor()}")
output_dir = model_local_dir(model.descriptor())
os.makedirs(output_dir, exist_ok=True)
try:
true_output_dir = snapshot_download(
repo_id,
local_dir=output_dir,
ignore_patterns=ignore_patterns,
token=hf_token,
library_name="llama-stack",
)
except GatedRepoError:
parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
except Exception as e:
parser.error(e)
print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str):
from llama_models.sku_list import llama_meta_net_info
from llama_stack.distribution.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True)
info = llama_meta_net_info(model)
# I believe we can use some concurrency here if needed but not sure it is worth it
for f in info.files:
output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0
cprint(f"Downloading `{f}`...", "white")
downloader = ResumableDownloader(url, output_file, total_size)
asyncio.run(downloader.download())
print(f"\nSuccessfully downloaded model to {output_dir}")
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import resolve_model
if args.manifest_file:
_download_from_manifest(args.manifest_file)
return
if args.model_id is None:
parser.error("Please provide a model id")
return
model = resolve_model(args.model_id)
if model is None:
parser.error(f"Model {args.model_id} not found")
return
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url
if not meta_url:
meta_url = input(
"Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url)
class ModelEntry(BaseModel):
model_id: str
files: Dict[str, str]
class Config:
protected_namespaces = ()
class Manifest(BaseModel):
models: List[ModelEntry]
expires_on: datetime
def _download_from_manifest(manifest_file: str):
from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f:
d = json.load(f)
manifest = Manifest(**d)
if datetime.now() > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
for entry in manifest.models:
print(f"Downloading model {entry.model_id}...")
output_dir = Path(model_local_dir(entry.model_id))
os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()):
cprint(f"Output directory {output_dir} is not empty.", "red")
while True:
resp = input(
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
)
if resp.lower() == "restart" or resp.lower() == "r":
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
break
elif resp.lower() == "continue" or resp.lower() == "c":
print("Continuing download...")
break
else:
cprint("Invalid response. Please try again.", "red")
for fname, url in entry.files.items():
output_file = str(output_dir / fname)
downloader = ResumableDownloader(url, output_file)
asyncio.run(downloader.download())
class ResumableDownloader:
def __init__(
self,
url: str,
output_file: str,
total_size: int = 0,
buffer_size: int = 32 * 1024,
):
self.url = url
self.output_file = output_file
self.buffer_size = buffer_size
self.total_size = total_size
self.downloaded_size = 0
self.start_size = 0
self.start_time = 0
async def get_file_info(self, client: httpx.AsyncClient) -> None:
if self.total_size > 0:
return
# Force disable compression when trying to retrieve file size
response = await client.head(
self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
)
response.raise_for_status()
self.url = str(response.url) # Update URL in case of redirects
self.total_size = int(response.headers.get("Content-Length", 0))
if self.total_size == 0:
raise ValueError(
"Unable to determine file size. The server might not support range requests."
)
async def download(self) -> None:
self.start_time = time.time()
async with httpx.AsyncClient(follow_redirects=True) as client:
await self.get_file_info(client)
if os.path.exists(self.output_file):
self.downloaded_size = os.path.getsize(self.output_file)
self.start_size = self.downloaded_size
if self.downloaded_size >= self.total_size:
print(f"Already downloaded `{self.output_file}`, skipping...")
return
additional_size = self.total_size - self.downloaded_size
if not self.has_disk_space(additional_size):
M = 1024 * 1024 # noqa
print(
f"Not enough disk space to download `{self.output_file}`. "
f"Required: {(additional_size // M):.2f} MB"
)
raise ValueError(
f"Not enough disk space to download `{self.output_file}`"
)
while True:
if self.downloaded_size >= self.total_size:
break
# Cloudfront has a max-size limit
max_chunk_size = 27_000_000_000
request_size = min(
self.total_size - self.downloaded_size, max_chunk_size
)
headers = {
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
}
print(f"Downloading `{self.output_file}`....{headers}")
try:
async with client.stream(
"GET", self.url, headers=headers
) as response:
response.raise_for_status()
with open(self.output_file, "ab") as file:
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
self.downloaded_size += len(chunk)
self.print_progress()
except httpx.HTTPError as e:
print(f"\nDownload interrupted: {e}")
print("You can resume the download by running the script again.")
except Exception as e:
print(f"\nAn error occurred: {e}")
print(f"\nFinished downloading `{self.output_file}`....")
def print_progress(self) -> None:
percent = (self.downloaded_size / self.total_size) * 100
bar_length = 50
filled_length = int(bar_length * self.downloaded_size // self.total_size)
bar = "" * filled_length + "-" * (bar_length - filled_length)
elapsed_time = time.time() - self.start_time
M = 1024 * 1024 # noqa
speed = (
(self.downloaded_size - self.start_size) / (elapsed_time * M)
if elapsed_time > 0
else 0
)
print(
f"\rProgress: |{bar}| {percent:.2f}% "
f"({self.downloaded_size // M}/{self.total_size // M} MB) "
f"Speed: {speed:.2f} MiB/s",
end="",
flush=True,
)
def has_disk_space(self, file_size: int) -> bool:
dir_path = os.path.dirname(os.path.abspath(self.output_file))
free_space = shutil.disk_usage(dir_path).free
return free_space > file_size