diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index 2a1c79220..f7365b7b4 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -6,18 +6,22 @@
 
 import argparse
 import asyncio
+import json
 import os
 import shutil
 import time
+from datetime import datetime
 from functools import partial
 from pathlib import Path
+from typing import Dict, List
 
 import httpx
-
-from llama_toolchain.cli.subcommand import Subcommand
+from pydantic import BaseModel
 
 from termcolor import cprint
 
+from llama_toolchain.cli.subcommand import Subcommand
+
 
 class Download(Subcommand):
     """Llama cli for downloading llama toolchain assets"""
@@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--model-id",
         choices=[x.descriptor() for x in models],
-        required=True,
+        required=False,
     )
     parser.add_argument(
         "--hf-token",
@@ -88,7 +92,7 @@ def _hf_download(
     if repo_id is None:
         raise ValueError(f"No repo id found for model {model.descriptor()}")
 
-    output_dir = model_local_dir(model)
+    output_dir = model_local_dir(model.descriptor())
     os.makedirs(output_dir, exist_ok=True)
     try:
         true_output_dir = snapshot_download(
@@ -118,7 +122,7 @@ def _meta_download(model: "Model", meta_url: str):
 
     from llama_toolchain.common.model_utils import model_local_dir
 
-    output_dir = Path(model_local_dir(model))
+    output_dir = Path(model_local_dir(model.descriptor()))
     os.makedirs(output_dir, exist_ok=True)
 
     info = llama_meta_net_info(model)
@@ -139,6 +143,14 @@ def _meta_download(model: "Model", meta_url: str):
 def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     from llama_models.sku_list import resolve_model
 
+    if args.manifest_file:
+        _download_from_manifest(args.manifest_file)
+        return
+
+    if args.model_id is None:
+        parser.error("Please provide a model id")
+        return
+
     model = resolve_model(args.model_id)
     if model is None:
         parser.error(f"Model {args.model_id} not found")
@@ -156,6 +168,65 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
         _meta_download(model, meta_url)
 
 
+class ModelEntry(BaseModel):
+    """One model in a download manifest: its id and a file name => URL mapping."""
+
+    model_id: str
+    files: Dict[str, str]
+
+
+class Manifest(BaseModel):
+    """A download manifest: the models to fetch and when its URLs expire."""
+
+    models: List[ModelEntry]
+    expires_on: datetime
+
+
+def _download_from_manifest(manifest_file: str):
+    """Download every file of every model listed in a JSON manifest.
+
+    Raises ValueError if the manifest's `expires_on` time has already passed.
+    Asks on stdin whether to continue or restart when a model dir is not empty.
+    """
+    from llama_toolchain.common.model_utils import model_local_dir
+
+    with open(manifest_file, "r", encoding="utf-8") as f:
+        d = json.load(f)
+        manifest = Manifest(**d)
+
+    # Take "now" in the manifest's own timezone (naive local time when it has
+    # none); comparing a naive datetime with an aware one raises TypeError.
+    if datetime.now(manifest.expires_on.tzinfo) > manifest.expires_on:
+        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
+
+    for entry in manifest.models:
+        print(f"Downloading model {entry.model_id}...")
+        output_dir = Path(model_local_dir(entry.model_id))
+        os.makedirs(output_dir, exist_ok=True)
+
+        if any(output_dir.iterdir()):
+            cprint(f"Output directory {output_dir} is not empty.", "red")
+
+            while True:
+                resp = input(
+                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
+                )
+                if resp.lower() == "restart" or resp.lower() == "r":
+                    shutil.rmtree(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
+                    break
+                elif resp.lower() == "continue" or resp.lower() == "c":
+                    print("Continuing download...")
+                    break
+                else:
+                    cprint("Invalid response. Please try again.", "red")
+
+        for fname, url in entry.files.items():
+            output_file = str(output_dir / fname)
+            downloader = ResumableDownloader(url, output_file)
+            asyncio.run(downloader.download())
+
+
 class ResumableDownloader:
     def __init__(
         self,
@@ -190,7 +261,7 @@ class ResumableDownloader:
 
     async def download(self) -> None:
         self.start_time = time.time()
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
             await self.get_file_info(client)
 
             if os.path.exists(self.output_file):
@@ -222,7 +293,7 @@ class ResumableDownloader:
                 headers = {
                     "Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
                 }
-                # print(f"Downloading `{self.output_file}`....{headers}")
+                print(f"Downloading `{self.output_file}`....{headers}")
                 try:
                     async with client.stream(
                         "GET", self.url, headers=headers
diff --git a/llama_toolchain/common/model_utils.py b/llama_toolchain/common/model_utils.py
index 282e02ea8..9e0c3f034 100644
--- a/llama_toolchain/common/model_utils.py
+++ b/llama_toolchain/common/model_utils.py
@@ -1,9 +1,13 @@
-import os
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
 
-from llama_models.datatypes import Model
+import os
 
 from .config_dirs import DEFAULT_CHECKPOINT_DIR
 
 
-def model_local_dir(model: Model) -> str:
-    return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
+def model_local_dir(descriptor: str) -> str:
+    return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py
index dfbaf1a3e..f4d3c210b 100644
--- a/llama_toolchain/inference/meta_reference/generation.py
+++ b/llama_toolchain/inference/meta_reference/generation.py
@@ -28,16 +28,16 @@ from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from llama_models.llama3_1.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
+from termcolor import cprint
 
 from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType
-from termcolor import cprint
 
 from .config import MetaReferenceImplConfig
 
 
 def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model))
+    checkpoint_dir = Path(model_local_dir(model.descriptor()))
     if not Path(checkpoint_dir / "consolidated.00.pth").exists():
         checkpoint_dir = checkpoint_dir / "original"
 
diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py
index 426376c2d..c669eed2f 100644
--- a/llama_toolchain/safety/meta_reference/safety.py
+++ b/llama_toolchain/safety/meta_reference/safety.py
@@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model)
+    model_dir = model_local_dir(model.descriptor())
     return model_dir
 
 