Add a --manifest-file option to llama download

Ashwin Bharambe 2024-08-17 10:08:00 -07:00
parent b8fc4d4dee
commit 5e072d0780
4 changed files with 78 additions and 14 deletions

View file

@@ -6,18 +6,22 @@
 import argparse
 import asyncio
+import json
 import os
+import shutil
 import time
+from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import Dict, List

 import httpx
-from llama_toolchain.cli.subcommand import Subcommand
+from pydantic import BaseModel
 from termcolor import cprint

+from llama_toolchain.cli.subcommand import Subcommand


 class Download(Subcommand):
     """Llama cli for downloading llama toolchain assets"""
@@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--model-id",
         choices=[x.descriptor() for x in models],
-        required=True,
+        required=False,
     )
     parser.add_argument(
         "--hf-token",
@@ -88,7 +92,7 @@ def _hf_download(
     if repo_id is None:
         raise ValueError(f"No repo id found for model {model.descriptor()}")
-    output_dir = model_local_dir(model)
+    output_dir = model_local_dir(model.descriptor())
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
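For context, the Hugging Face branch above hands the actual transfer to huggingface_hub's snapshot_download. A minimal standalone sketch of that call, with a hypothetical repo id and output directory (the real values come from the resolved model and model_local_dir):

    from huggingface_hub import snapshot_download

    # Hypothetical values; _hf_download derives these from the resolved model.
    repo_id = "meta-llama/Meta-Llama-3.1-8B"
    output_dir = "/tmp/checkpoints/Meta-Llama-3.1-8B"

    # Downloads (or resumes) the repo snapshot into output_dir and returns
    # the directory that was actually used.
    true_output_dir = snapshot_download(
        repo_id=repo_id,
        local_dir=output_dir,
        token=None,  # pass the --hf-token value here for gated repos
    )
    print(true_output_dir)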
@@ -118,7 +122,7 @@ def _meta_download(model: "Model", meta_url: str):
     from llama_toolchain.common.model_utils import model_local_dir

-    output_dir = Path(model_local_dir(model))
+    output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)

     info = llama_meta_net_info(model)
@@ -139,6 +143,14 @@ def _meta_download(model: "Model", meta_url: str):
 def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     from llama_models.sku_list import resolve_model

+    if args.manifest_file:
+        _download_from_manifest(args.manifest_file)
+        return
+
+    if args.model_id is None:
+        parser.error("Please provide a model id")
+        return
+
     model = resolve_model(args.model_id)
     if model is None:
         parser.error(f"Model {args.model_id} not found")
@@ -156,6 +168,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         _meta_download(model, meta_url)


+class ModelEntry(BaseModel):
+    model_id: str
+    files: Dict[str, str]
+
+
+class Manifest(BaseModel):
+    models: List[ModelEntry]
+    expires_on: datetime
+
+
+def _download_from_manifest(manifest_file: str):
+    from llama_toolchain.common.model_utils import model_local_dir
+
+    with open(manifest_file, "r") as f:
+        d = json.load(f)
+        manifest = Manifest(**d)
+
+    if datetime.now() > manifest.expires_on:
+        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
+
+    for entry in manifest.models:
+        print(f"Downloading model {entry.model_id}...")
+        output_dir = Path(model_local_dir(entry.model_id))
+        os.makedirs(output_dir, exist_ok=True)
+
+        if any(output_dir.iterdir()):
+            cprint(f"Output directory {output_dir} is not empty.", "red")
+            while True:
+                resp = input(
+                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
+                )
+                if resp.lower() == "restart" or resp.lower() == "r":
+                    shutil.rmtree(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
+                    break
+                elif resp.lower() == "continue" or resp.lower() == "c":
+                    print("Continuing download...")
+                    break
+                else:
+                    cprint("Invalid response. Please try again.", "red")
+
+        for fname, url in entry.files.items():
+            output_file = str(output_dir / fname)
+            downloader = ResumableDownloader(url, output_file)
+            asyncio.run(downloader.download())
+
+
 class ResumableDownloader:
     def __init__(
         self,
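A manifest accepted by _download_from_manifest has to satisfy the two pydantic models added above. A self-contained sketch of one, with purely illustrative model ids, file names, and URLs; pydantic parses the ISO-8601 expires_on string into a datetime:

    from datetime import datetime
    from typing import Dict, List

    from pydantic import BaseModel

    # Same shapes as the ModelEntry/Manifest models in the diff above.
    class ModelEntry(BaseModel):
        model_id: str
        files: Dict[str, str]

    class Manifest(BaseModel):
        models: List[ModelEntry]
        expires_on: datetime

    # Hypothetical manifest content; file names and URLs are illustrative.
    sample = {
        "models": [
            {
                "model_id": "Meta-Llama3.1-8B",
                "files": {
                    "consolidated.00.pth": "https://example.com/signed-url-1",
                    "params.json": "https://example.com/signed-url-2",
                },
            }
        ],
        "expires_on": "2024-08-20T00:00:00",
    }
    manifest = Manifest(**sample)
    assert manifest.expires_on.year == 2024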
@@ -190,7 +250,7 @@ class ResumableDownloader:
     async def download(self) -> None:
         self.start_time = time.time()
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
             await self.get_file_info(client)

             if os.path.exists(self.output_file):
@@ -222,7 +282,7 @@ class ResumableDownloader:
            headers = {
                "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
            }
-            # print(f"Downloading `{self.output_file}`....{headers}")
+            print(f"Downloading `{self.output_file}`....{headers}")
            try:
                async with client.stream(
                    "GET", self.url, headers=headers

View file

@@ -1,9 +1,13 @@
-import os
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

-from llama_models.datatypes import Model
+import os

 from .config_dirs import DEFAULT_CHECKPOINT_DIR


-def model_local_dir(model: Model) -> str:
-    return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
+def model_local_dir(descriptor: str) -> str:
+    return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
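The signature change decouples path resolution from the Model type (and drops the llama_models import), so callers pass the descriptor string themselves; that is what lets _download_from_manifest resolve a directory from a bare model_id. A before/after sketch, with an illustrative checkpoint root:

    import os

    DEFAULT_CHECKPOINT_DIR = os.path.expanduser("~/.llama/checkpoints")  # illustrative

    def model_local_dir(descriptor: str) -> str:
        return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)

    # Old call sites:  model_local_dir(model)
    # New call sites:  model_local_dir(model.descriptor())
    # Manifest path:   model_local_dir(entry.model_id)
    print(model_local_dir("Meta-Llama3.1-8B"))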

View file

@@ -28,16 +28,16 @@ from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from llama_models.llama3_1.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
-from termcolor import cprint

 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
+from termcolor import cprint

 from .config import MetaReferenceImplConfig


 def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model))
+    checkpoint_dir = Path(model_local_dir(model.descriptor()))

     if not Path(checkpoint_dir / "consolidated.00.pth").exists():
         checkpoint_dir = checkpoint_dir / "original"

View file

@@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model)
+    model_dir = model_local_dir(model.descriptor())
     return model_dir