forked from phoenix-oss/llama-stack-mirror
Add a --manifest-file
option to llama download
This commit is contained in:
parent
b8fc4d4dee
commit
5e072d0780
4 changed files with 78 additions and 14 deletions
|
@ -6,18 +6,22 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from pydantic import BaseModel
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
|
||||
|
||||
class Download(Subcommand):
|
||||
"""Llama cli for downloading llama toolchain assets"""
|
||||
|
@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
|||
parser.add_argument(
|
||||
"--model-id",
|
||||
choices=[x.descriptor() for x in models],
|
||||
required=True,
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-token",
|
||||
|
@ -88,7 +92,7 @@ def _hf_download(
|
|||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
||||
output_dir = model_local_dir(model)
|
||||
output_dir = model_local_dir(model.descriptor())
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
|
@ -118,7 +122,7 @@ def _meta_download(model: "Model", meta_url: str):
|
|||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(model_local_dir(model))
|
||||
output_dir = Path(model_local_dir(model.descriptor()))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
info = llama_meta_net_info(model)
|
||||
|
@ -139,6 +143,14 @@ def _meta_download(model: "Model", meta_url: str):
|
|||
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
if args.manifest_file:
|
||||
_download_from_manifest(args.manifest_file)
|
||||
return
|
||||
|
||||
if args.model_id is None:
|
||||
parser.error("Please provide a model id")
|
||||
return
|
||||
|
||||
model = resolve_model(args.model_id)
|
||||
if model is None:
|
||||
parser.error(f"Model {args.model_id} not found")
|
||||
|
@ -156,6 +168,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
_meta_download(model, meta_url)
|
||||
|
||||
|
||||
class ModelEntry(BaseModel):
|
||||
model_id: str
|
||||
files: Dict[str, str]
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
models: List[ModelEntry]
|
||||
expires_on: datetime
|
||||
|
||||
|
||||
def _download_from_manifest(manifest_file: str):
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
with open(manifest_file, "r") as f:
|
||||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now() > manifest.expires_on:
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
for entry in manifest.models:
|
||||
print(f"Downloading model {entry.model_id}...")
|
||||
output_dir = Path(model_local_dir(entry.model_id))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if any(output_dir.iterdir()):
|
||||
cprint(f"Output directory {output_dir} is not empty.", "red")
|
||||
|
||||
while True:
|
||||
resp = input(
|
||||
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
|
||||
)
|
||||
if resp.lower() == "restart" or resp.lower() == "r":
|
||||
shutil.rmtree(output_dir)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
break
|
||||
elif resp.lower() == "continue" or resp.lower() == "c":
|
||||
print("Continuing download...")
|
||||
break
|
||||
else:
|
||||
cprint("Invalid response. Please try again.", "red")
|
||||
|
||||
for fname, url in entry.files.items():
|
||||
output_file = str(output_dir / fname)
|
||||
downloader = ResumableDownloader(url, output_file)
|
||||
asyncio.run(downloader.download())
|
||||
|
||||
|
||||
class ResumableDownloader:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -190,7 +250,7 @@ class ResumableDownloader:
|
|||
|
||||
async def download(self) -> None:
|
||||
self.start_time = time.time()
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
await self.get_file_info(client)
|
||||
|
||||
if os.path.exists(self.output_file):
|
||||
|
@ -222,7 +282,7 @@ class ResumableDownloader:
|
|||
headers = {
|
||||
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
|
||||
}
|
||||
# print(f"Downloading `{self.output_file}`....{headers}")
|
||||
print(f"Downloading `{self.output_file}`....{headers}")
|
||||
try:
|
||||
async with client.stream(
|
||||
"GET", self.url, headers=headers
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import os
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.datatypes import Model
|
||||
import os
|
||||
|
||||
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
|
||||
def model_local_dir(model: Model) -> str:
|
||||
return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())
|
||||
def model_local_dir(descriptor: str) -> str:
|
||||
return os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
|
||||
|
|
|
@ -28,16 +28,16 @@ from llama_models.llama3_1.api.datatypes import Message
|
|||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3_1.reference_impl.model import Transformer
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
from llama_toolchain.inference.api import QuantizationType
|
||||
from termcolor import cprint
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model))
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
|
|||
def resolve_and_get_path(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert model is not None, f"Could not resolve model {model_name}"
|
||||
model_dir = model_local_dir(model)
|
||||
model_dir = model_local_dir(model.descriptor())
|
||||
return model_dir
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue