Add a --manifest-file option to llama download

Author: Ashwin Bharambe
Date: 2024-08-17 10:08:00 -07:00
Commit: 5e072d0780 (parent b8fc4d4dee)
4 changed files with 78 additions and 14 deletions
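
For context, the new flag consumes a JSON manifest matching the Manifest/ModelEntry schema added in this commit. Below is a minimal sketch of producing such a file; the model id, file names, and URLs are hypothetical placeholders, and `expires_on` is serialized as an ISO 8601 string, which pydantic parses back into a datetime.

# Sketch: build a manifest matching the ModelEntry/Manifest schema below.
# The model id and URLs here are hypothetical placeholders.
import json
from datetime import datetime, timedelta

manifest = {
    "models": [
        {
            "model_id": "Meta-Llama3.1-8B-Instruct",  # hypothetical descriptor
            "files": {
                "consolidated.00.pth": "https://example.com/signed-url-1",
                "params.json": "https://example.com/signed-url-2",
                "tokenizer.model": "https://example.com/signed-url-3",
            },
        }
    ],
    # pydantic parses ISO 8601 strings into the `expires_on: datetime` field
    "expires_on": (datetime.now() + timedelta(hours=24)).isoformat(),
}

with open("manifest.json", "w") as f:
    json.dump(manifest, f, indent=2)

# then: llama download --manifest-file manifest.json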

File 1 of 4

@@ -6,18 +6,22 @@
 import argparse
 import asyncio
+import json
 import os
 import shutil
 import time
+from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import Dict, List
 
 import httpx
-from llama_toolchain.cli.subcommand import Subcommand
+from pydantic import BaseModel
 from termcolor import cprint
+
+from llama_toolchain.cli.subcommand import Subcommand
 
 
 class Download(Subcommand):
     """Llama cli for downloading llama toolchain assets"""
@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
"--model-id", "--model-id",
choices=[x.descriptor() for x in models], choices=[x.descriptor() for x in models],
required=True, required=False,
) )
parser.add_argument( parser.add_argument(
"--hf-token", "--hf-token",
@@ -88,7 +92,7 @@ def _hf_download(
     if repo_id is None:
         raise ValueError(f"No repo id found for model {model.descriptor()}")
 
-    output_dir = model_local_dir(model)
+    output_dir = model_local_dir(model.descriptor())
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
@@ -118,7 +122,7 @@ def _meta_download(model: "Model", meta_url: str):
     from llama_toolchain.common.model_utils import model_local_dir
 
-    output_dir = Path(model_local_dir(model))
+    output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)
 
     info = llama_meta_net_info(model)
@@ -139,6 +143,14 @@ def _meta_download(model: "Model", meta_url: str):
 def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     from llama_models.sku_list import resolve_model
 
+    if args.manifest_file:
+        _download_from_manifest(args.manifest_file)
+        return
+
+    if args.model_id is None:
+        parser.error("Please provide a model id")
+        return
+
     model = resolve_model(args.model_id)
     if model is None:
         parser.error(f"Model {args.model_id} not found")
@@ -156,6 +168,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     _meta_download(model, meta_url)
 
 
+class ModelEntry(BaseModel):
+    model_id: str
+    files: Dict[str, str]
+
+
+class Manifest(BaseModel):
+    models: List[ModelEntry]
+    expires_on: datetime
+
+
+def _download_from_manifest(manifest_file: str):
+    from llama_toolchain.common.model_utils import model_local_dir
+
+    with open(manifest_file, "r") as f:
+        d = json.load(f)
+        manifest = Manifest(**d)
+
+    if datetime.now() > manifest.expires_on:
+        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
+
+    for entry in manifest.models:
+        print(f"Downloading model {entry.model_id}...")
+        output_dir = Path(model_local_dir(entry.model_id))
+        os.makedirs(output_dir, exist_ok=True)
+
+        if any(output_dir.iterdir()):
+            cprint(f"Output directory {output_dir} is not empty.", "red")
+
+            while True:
+                resp = input(
+                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
+                )
+                if resp.lower() == "restart" or resp.lower() == "r":
+                    shutil.rmtree(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
+                    break
+                elif resp.lower() == "continue" or resp.lower() == "c":
+                    print("Continuing download...")
+                    break
+                else:
+                    cprint("Invalid response. Please try again.", "red")
+
+        for fname, url in entry.files.items():
+            output_file = str(output_dir / fname)
+            downloader = ResumableDownloader(url, output_file)
+            asyncio.run(downloader.download())
+
+
 class ResumableDownloader:
     def __init__(
         self,
@@ -190,7 +250,7 @@ class ResumableDownloader:
     async def download(self) -> None:
         self.start_time = time.time()
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
             await self.get_file_info(client)
 
             if os.path.exists(self.output_file):
@ -222,7 +282,7 @@ class ResumableDownloader:
headers = { headers = {
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}" "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
} }
# print(f"Downloading `{self.output_file}`....{headers}") print(f"Downloading `{self.output_file}`....{headers}")
try: try:
async with client.stream( async with client.stream(
"GET", self.url, headers=headers "GET", self.url, headers=headers

File 2 of 4

@@ -1,9 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import os
 
-from llama_models.datatypes import Model
-
 from .config_dirs import DEFAULT_CHECKPOINT_DIR
 
 
-def model_local_dir(model: Model) -> str:
-    return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
+def model_local_dir(descriptor: str) -> str:
+    return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
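
`model_local_dir` now takes a plain descriptor string rather than a `Model`, which is what lets the manifest path above reuse it with a bare `model_id`. A sketch of the call-site change (the descriptor value is a hypothetical placeholder):

from llama_toolchain.common.model_utils import model_local_dir

# Before this commit: model_local_dir(model)  # required a Model object
# After: any descriptor string works, e.g. straight from a manifest entry
output_dir = model_local_dir("Meta-Llama3.1-8B-Instruct")  # hypothetical id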

File 3 of 4

@@ -28,16 +28,16 @@ from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from llama_models.llama3_1.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
+from termcolor import cprint
 
 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
-from termcolor import cprint
 
 from .config import MetaReferenceImplConfig
 
 
 def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model))
+    checkpoint_dir = Path(model_local_dir(model.descriptor()))
 
     if not Path(checkpoint_dir / "consolidated.00.pth").exists():
         checkpoint_dir = checkpoint_dir / "original"

File 4 of 4

@@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
 
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model)
+    model_dir = model_local_dir(model.descriptor())
     return model_dir