forked from phoenix-oss/llama-stack-mirror
Add a --manifest-file
option to llama download
This commit is contained in:
parent
b8fc4d4dee
commit
5e072d0780
4 changed files with 78 additions and 14 deletions
|
@ -6,18 +6,22 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
class Download(Subcommand):
|
class Download(Subcommand):
|
||||||
"""Llama cli for downloading llama toolchain assets"""
|
"""Llama cli for downloading llama toolchain assets"""
|
||||||
|
@ -45,7 +49,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model-id",
|
"--model-id",
|
||||||
choices=[x.descriptor() for x in models],
|
choices=[x.descriptor() for x in models],
|
||||||
required=True,
|
required=False,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hf-token",
|
"--hf-token",
|
||||||
|
@ -88,7 +92,7 @@ def _hf_download(
|
||||||
if repo_id is None:
|
if repo_id is None:
|
||||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||||
|
|
||||||
output_dir = model_local_dir(model)
|
output_dir = model_local_dir(model.descriptor())
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
true_output_dir = snapshot_download(
|
true_output_dir = snapshot_download(
|
||||||
|
@ -118,7 +122,7 @@ def _meta_download(model: "Model", meta_url: str):
|
||||||
|
|
||||||
from llama_toolchain.common.model_utils import model_local_dir
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
|
|
||||||
output_dir = Path(model_local_dir(model))
|
output_dir = Path(model_local_dir(model.descriptor()))
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
info = llama_meta_net_info(model)
|
info = llama_meta_net_info(model)
|
||||||
|
@ -139,6 +143,14 @@ def _meta_download(model: "Model", meta_url: str):
|
||||||
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
|
if args.manifest_file:
|
||||||
|
_download_from_manifest(args.manifest_file)
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.model_id is None:
|
||||||
|
parser.error("Please provide a model id")
|
||||||
|
return
|
||||||
|
|
||||||
model = resolve_model(args.model_id)
|
model = resolve_model(args.model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
parser.error(f"Model {args.model_id} not found")
|
parser.error(f"Model {args.model_id} not found")
|
||||||
|
@ -156,6 +168,54 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||||
_meta_download(model, meta_url)
|
_meta_download(model, meta_url)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEntry(BaseModel):
    """One downloadable model listed in a --manifest-file JSON document.

    Fields are parsed and validated by pydantic from the manifest JSON.
    """

    # Model descriptor; also used to derive the local output directory.
    model_id: str
    # Maps destination filename -> URL to fetch it from.
    # NOTE(review): values are presumably presigned URLs — confirm with manifest producer.
    files: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class Manifest(BaseModel):
    """Top-level schema of a --manifest-file download manifest."""

    # All models (and their per-file URLs) offered by this manifest.
    models: List[ModelEntry]
    # The manifest's URLs stop working after this moment; checked before any download starts.
    expires_on: datetime
|
||||||
|
|
||||||
|
|
||||||
|
def _download_from_manifest(manifest_file: str):
    """Download every model listed in a manifest JSON file.

    The manifest (see ``Manifest``) carries per-model file URLs plus an
    expiry timestamp; an expired manifest is rejected up front with a
    ValueError. Each model's files are fetched into its local checkpoint
    directory via ``ResumableDownloader``. If a target directory already
    has contents, the user is asked interactively whether to continue
    (resume) the download or wipe the directory and restart.
    """
    from llama_toolchain.common.model_utils import model_local_dir

    with open(manifest_file, "r") as f:
        manifest = Manifest(**json.load(f))

    # NOTE(review): naive datetime.now() compared against manifest timestamp —
    # assumes expires_on is naive local time; confirm with manifest producer.
    if datetime.now() > manifest.expires_on:
        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

    for entry in manifest.models:
        print(f"Downloading model {entry.model_id}...")
        output_dir = Path(model_local_dir(entry.model_id))
        os.makedirs(output_dir, exist_ok=True)

        if any(output_dir.iterdir()):
            # Directory is non-empty: let the user resume or start over.
            cprint(f"Output directory {output_dir} is not empty.", "red")
            while True:
                resp = input(
                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                ).lower()
                if resp in ("restart", "r"):
                    shutil.rmtree(output_dir)
                    os.makedirs(output_dir, exist_ok=True)
                    break
                if resp in ("continue", "c"):
                    print("Continuing download...")
                    break
                cprint("Invalid response. Please try again.", "red")

        for fname, url in entry.files.items():
            # Each file download runs its own event loop; resumes partial files.
            asyncio.run(ResumableDownloader(url, str(output_dir / fname)).download())
|
||||||
|
|
||||||
|
|
||||||
class ResumableDownloader:
|
class ResumableDownloader:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -190,7 +250,7 @@ class ResumableDownloader:
|
||||||
|
|
||||||
async def download(self) -> None:
|
async def download(self) -> None:
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
await self.get_file_info(client)
|
await self.get_file_info(client)
|
||||||
|
|
||||||
if os.path.exists(self.output_file):
|
if os.path.exists(self.output_file):
|
||||||
|
@ -222,7 +282,7 @@ class ResumableDownloader:
|
||||||
headers = {
|
headers = {
|
||||||
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
|
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
|
||||||
}
|
}
|
||||||
# print(f"Downloading `{self.output_file}`....{headers}")
|
print(f"Downloading `{self.output_file}`....{headers}")
|
||||||
try:
|
try:
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"GET", self.url, headers=headers
|
"GET", self.url, headers=headers
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
import os
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_models.datatypes import Model
|
import os
|
||||||
|
|
||||||
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||||
|
|
||||||
|
|
||||||
def model_local_dir(descriptor: str) -> str:
    """Return the local on-disk directory for a model *descriptor*.

    The directory lives under the default checkpoint root; callers pass
    ``model.descriptor()`` (or a manifest model_id) rather than a Model object.
    """
    local_path = os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
    return local_path
|
||||||
|
|
|
@ -28,16 +28,16 @@ from llama_models.llama3_1.api.datatypes import Message
|
||||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||||
from llama_models.llama3_1.reference_impl.model import Transformer
|
from llama_models.llama3_1.reference_impl.model import Transformer
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.common.model_utils import model_local_dir
|
from llama_toolchain.common.model_utils import model_local_dir
|
||||||
from llama_toolchain.inference.api import QuantizationType
|
from llama_toolchain.inference.api import QuantizationType
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from .config import MetaReferenceImplConfig
|
from .config import MetaReferenceImplConfig
|
||||||
|
|
||||||
|
|
||||||
def model_checkpoint_dir(model) -> str:
|
def model_checkpoint_dir(model) -> str:
|
||||||
checkpoint_dir = Path(model_local_dir(model))
|
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||||
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||||
checkpoint_dir = checkpoint_dir / "original"
|
checkpoint_dir = checkpoint_dir / "original"
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
|
||||||
def resolve_and_get_path(model_name: str) -> str:
    """Resolve *model_name* to a known model and return its local directory.

    Raises AssertionError if the name does not match any known model.
    """
    resolved = resolve_model(model_name)
    assert resolved is not None, f"Could not resolve model {model_name}"
    return model_local_dir(resolved.descriptor())
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue