llama-stack/llama_stack/cli/download.py
Ashwin Bharambe 9487ad8294
API Updates (#73)
* API Keys passed from Client instead of distro configuration

* delete distribution registry

* Rename the "package" word away

* Introduce a "Router" layer for providers

Some providers need to be factorized and considered as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a keyvalue memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.

* update `apis_to_serve`

* llama_toolchain -> llama_stack

* Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- from llama_stack.apis.<api> import foo should work now
- update imports to do llama_stack.apis.<api>
- update many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent

* Moved some stuff out of common/; re-generated OpenAPI spec

* llama-toolchain -> llama-stack (hyphens)

* add control plane API

* add redis adapter + sqlite provider

* move core -> distribution

* Some more toolchain -> stack changes

* small naming shenanigans

* Removing custom tool and agent utilities and moving them client side

* Move control plane to distribution server for now

* Remove control plane from API list

* no codeshield dependency randomly plzzzzz

* Add "fire" as a dependency

* add back event loggers

* stack configure fixes

* use brave instead of bing in the example client

* add init file so it gets packaged

* add init files so it gets packaged

* Update MANIFEST

* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
2024-09-17 19:51:35 -07:00

339 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
import shutil
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Dict, List
import httpx
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
class Download(Subcommand):
    """CLI subcommand that registers `llama download` on the given subparsers."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        # Create the "download" sub-parser first, then let the module-level
        # helper attach its options and its command handler.
        download_parser = subparsers.add_parser(
            "download",
            prog="llama download",
            description="Download a model from llama.meta.com or Hugging Face Hub",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self.parser = download_parser
        setup_download_parser(download_parser)
def setup_download_parser(parser: argparse.ArgumentParser) -> None:
    """Attach the `llama download` options and its handler to ``parser``.

    Registers ``--source``, ``--model-id``, ``--hf-token``, ``--meta-url``,
    ``--ignore-patterns`` and ``--manifest-file``, then sets ``func`` to
    ``run_download_cmd`` bound to this parser so the handler can report
    errors through ``parser.error``.

    Args:
        parser: The argument parser to configure in place.
    """
    parser.add_argument(
        "--source",
        choices=["meta", "huggingface"],
        required=True,
    )
    parser.add_argument(
        "--model-id",
        required=False,
        help="See `llama model list` or `llama model list --show-all` for the list of available models",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        required=False,
        default=None,
        help="Hugging Face API token. Needed for gated models like llama2/3. Will also try to read environment variable `HF_TOKEN` as default.",
    )
    parser.add_argument(
        "--meta-url",
        type=str,
        required=False,
        help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
    )
    parser.add_argument(
        "--ignore-patterns",
        type=str,
        required=False,
        default="*.safetensors",
        # Kept flush-left on purpose: the literal's exact text (including its
        # leading and trailing newlines) is the runtime help string.
        help="""
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights.
""",
    )
    parser.add_argument(
        "--manifest-file",
        type=str,
        help="For source=meta, you can download models from a manifest file containing a file => URL mapping",
        required=False,
    )
    # The handler receives the parsed args positionally; the parser is bound
    # here as a keyword argument.
    parser.set_defaults(func=partial(run_download_cmd, parser=parser))
def _hf_download(
    model: "Model",
    hf_token: str,
    ignore_patterns: str,
    parser: argparse.ArgumentParser,
):
    """Download ``model`` from the Hugging Face Hub into its local model directory.

    Args:
        model: Resolved model; its ``huggingface_repo`` names the repository and
            ``descriptor()`` selects the local output directory.
        hf_token: Hugging Face API token forwarded to ``snapshot_download``
            (``None`` when ``--hf-token`` was not supplied).
        ignore_patterns: File pattern(s) to skip, forwarded to ``snapshot_download``.
        parser: Used to report download failures via ``parser.error``, which
            prints the message and exits.

    Raises:
        ValueError: If the model has no Hugging Face repository id.
    """
    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

    from llama_stack.distribution.utils.model_utils import model_local_dir

    repo_id = model.huggingface_repo
    if repo_id is None:
        raise ValueError(f"No repo id found for model {model.descriptor()}")

    output_dir = model_local_dir(model.descriptor())
    os.makedirs(output_dir, exist_ok=True)
    try:
        true_output_dir = snapshot_download(
            repo_id,
            local_dir=output_dir,
            ignore_patterns=ignore_patterns,
            token=hf_token,
            library_name="llama-stack",
        )
    except GatedRepoError:
        parser.error(
            "It looks like you are trying to access a gated repository. Please ensure you "
            "have access to the repository and have provided the proper Hugging Face API token "
            "using the option `--hf-token` or by running `huggingface-cli login`. "
            "You can find your token by visiting https://huggingface.co/settings/tokens"
        )
    except RepositoryNotFoundError:
        # Report the repo id resolved above; there is no `args` object in this scope.
        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
    except Exception as e:
        parser.error(str(e))

    print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str):
    """Download every file of ``model`` through the signed ``meta_url``.

    Args:
        model: Resolved model; ``descriptor()`` selects the local output
            directory and ``llama_meta_net_info`` supplies the file list.
        meta_url: Signed URL containing a ``*`` placeholder, which is replaced
            by ``<folder>/<file>`` for each file.
    """
    from llama_models.sku_list import llama_meta_net_info

    from llama_stack.distribution.utils.model_utils import model_local_dir

    target_dir = Path(model_local_dir(model.descriptor()))
    os.makedirs(target_dir, exist_ok=True)

    net_info = llama_meta_net_info(model)

    # Files are fetched one after another; concurrency would be possible here
    # but it is unclear whether it would be worth it.
    for name in net_info.files:
        destination = str(target_dir / name)
        file_url = meta_url.replace("*", f"{net_info.folder}/{name}")
        # Only "consolidated" files get a size up front; 0 is passed for the rest.
        if "consolidated" in name:
            expected_size = net_info.pth_size
        else:
            expected_size = 0
        cprint(f"Downloading `{name}`...", "white")
        asyncio.run(ResumableDownloader(file_url, destination, expected_size).download())

    print(f"\nSuccessfully downloaded model to {target_dir}")
    cprint(f"\nMD5 Checksums are at: {target_dir / 'checklist.chk'}", "white")
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Handle `llama download`: dispatch to the manifest, Hugging Face or Meta path.

    A ``--manifest-file`` takes precedence over everything else. Otherwise a
    model id is required, is resolved to a model, and is downloaded from the
    source chosen with ``--source``.

    Args:
        args: Parsed command-line arguments (``manifest_file``, ``model_id``,
            ``source``, ``hf_token``, ``ignore_patterns``, ``meta_url``).
        parser: Used to report usage errors via ``parser.error``, which prints
            the message and exits.
    """
    if args.manifest_file:
        _download_from_manifest(args.manifest_file)
        return

    if args.model_id is None:
        parser.error("Please provide a model id")
        return

    # Imported only once a model actually has to be resolved, so the checks
    # above do not depend on this package being importable.
    from llama_models.sku_list import resolve_model

    model = resolve_model(args.model_id)
    if model is None:
        parser.error(f"Model {args.model_id} not found")
        return

    if args.source == "huggingface":
        _hf_download(model, args.hf_token, args.ignore_patterns, parser)
        return

    meta_url = args.meta_url
    if not meta_url:
        meta_url = input(
            "Please provide the signed URL you received via email (e.g., https://llama3-1.llamameta.net/*?Policy...): "
        )
    # Validate explicitly instead of with `assert`: assertions are stripped
    # under `python -O` and give the user no actionable message.
    if not meta_url or "llamameta.net" not in meta_url:
        parser.error(
            "Invalid signed URL: expected a llamameta.net URL "
            "(pass it with --meta-url or paste it at the prompt)"
        )
        return
    _meta_download(model, meta_url)
class ModelEntry(BaseModel):
    """One model listed in a download manifest, with the files that make it up."""

    class Config:
        # Empty tuple disables pydantic's protected-namespace check, which
        # (in pydantic v2) would otherwise flag the `model_`-prefixed field below.
        protected_namespaces = ()

    # Identifier of the model; also used to choose its local output directory.
    model_id: str
    # File name => download URL.
    files: Dict[str, str]
class Manifest(BaseModel):
    """Download manifest: the models to fetch and when their URLs expire."""

    # Models to download, each carrying its own file => URL mapping.
    models: List[ModelEntry]
    # Point in time after which the manifest's URLs are treated as expired.
    expires_on: datetime
def _download_from_manifest(manifest_file: str):
    """Download every file of every model listed in a JSON manifest.

    For each model the local output directory is created. If it already
    contains anything, the user is asked whether to continue the existing
    download or to wipe the directory and restart.

    Args:
        manifest_file: Path to a JSON file matching the ``Manifest`` schema.

    Raises:
        ValueError: If the manifest's ``expires_on`` time has already passed.
    """
    from llama_stack.distribution.utils.model_utils import model_local_dir

    with open(manifest_file, "r", encoding="utf-8") as f:
        d = json.load(f)
    manifest = Manifest(**d)

    # Take "now" in the same timezone-awareness as `expires_on`: comparing a
    # naive datetime with an aware one raises TypeError. With tzinfo=None this
    # is the plain naive local time.
    expires_on = manifest.expires_on
    if datetime.now(tz=expires_on.tzinfo) > expires_on:
        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")

    for entry in manifest.models:
        print(f"Downloading model {entry.model_id}...")
        output_dir = Path(model_local_dir(entry.model_id))
        os.makedirs(output_dir, exist_ok=True)

        if any(output_dir.iterdir()):
            cprint(f"Output directory {output_dir} is not empty.", "red")
            # Keep asking until one of the two recognised answers is given.
            while True:
                resp = input(
                    "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
                )
                choice = resp.lower()
                if choice in ("restart", "r"):
                    shutil.rmtree(output_dir)
                    os.makedirs(output_dir, exist_ok=True)
                    break
                if choice in ("continue", "c"):
                    print("Continuing download...")
                    break
                cprint("Invalid response. Please try again.", "red")

        for fname, url in entry.files.items():
            output_file = str(output_dir / fname)
            downloader = ResumableDownloader(url, output_file)
            asyncio.run(downloader.download())
class ResumableDownloader:
    """Download one URL into a local file with HTTP Range requests.

    Bytes are appended to ``output_file``; if the file already exists, the
    download continues from its current size, so an interrupted download can
    be resumed by running it again.
    """

    # Largest number of bytes asked for in a single Range request
    # (Cloudfront has a max-size limit).
    MAX_CHUNK_SIZE = 27_000_000_000

    def __init__(
        self,
        url: str,
        output_file: str,
        total_size: int = 0,
        buffer_size: int = 32 * 1024,
    ):
        """Store the download parameters.

        Args:
            url: Source URL.
            output_file: Destination path; data is appended to it.
            total_size: Expected size in bytes; 0 means "ask the server".
            buffer_size: Chunk size used when streaming the response body.
        """
        self.url = url
        self.output_file = output_file
        self.buffer_size = buffer_size
        self.total_size = total_size
        self.downloaded_size = 0  # bytes present in output_file so far
        self.start_size = 0  # bytes already on disk when this run started
        self.start_time = 0  # time.time() at the start of download()

    async def get_file_info(self, client: "httpx.AsyncClient") -> None:
        """Determine ``total_size`` with a HEAD request unless it is already known.

        Also replaces ``self.url`` with the final URL after redirects.

        Raises:
            ValueError: If the response carries no usable ``Content-Length``.
        """
        if self.total_size > 0:
            return

        # Force disable compression when trying to retrieve file size
        response = await client.head(
            self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"}
        )
        response.raise_for_status()
        self.url = str(response.url)  # Update URL in case of redirects
        self.total_size = int(response.headers.get("Content-Length", 0))
        if self.total_size == 0:
            raise ValueError(
                "Unable to determine file size. The server might not support range requests."
            )

    def _next_range_header(self) -> Dict[str, str]:
        """Return the Range header for the next request.

        Both ends of an HTTP byte range are inclusive, so a request for
        ``n`` bytes starting at ``start`` ends at ``start + n - 1``.
        """
        start = self.downloaded_size
        request_size = min(self.total_size - start, self.MAX_CHUNK_SIZE)
        return {"Range": f"bytes={start}-{start + request_size - 1}"}

    async def download(self) -> None:
        """Download (or resume downloading) the file.

        Returns early when the file on disk is already complete. When a request
        fails, the error is printed and the method returns without retrying;
        the partial file is kept so a later run can resume.

        Raises:
            ValueError: If the size cannot be determined or there is not
                enough free disk space for the remaining bytes.
        """
        self.start_time = time.time()
        async with httpx.AsyncClient(follow_redirects=True) as client:
            await self.get_file_info(client)

            if os.path.exists(self.output_file):
                self.downloaded_size = os.path.getsize(self.output_file)
                self.start_size = self.downloaded_size
                if self.downloaded_size >= self.total_size:
                    print(f"Already downloaded `{self.output_file}`, skipping...")
                    return

            additional_size = self.total_size - self.downloaded_size
            if not self.has_disk_space(additional_size):
                mib = 1024 * 1024
                print(
                    f"Not enough disk space to download `{self.output_file}`. "
                    f"Required: {(additional_size / mib):.2f} MB"
                )
                raise ValueError(
                    f"Not enough disk space to download `{self.output_file}`"
                )

            while self.downloaded_size < self.total_size:
                headers = self._next_range_header()
                print(f"Downloading `{self.output_file}`....{headers}")
                try:
                    async with client.stream(
                        "GET", self.url, headers=headers
                    ) as response:
                        response.raise_for_status()
                        with open(self.output_file, "ab") as file:
                            async for chunk in response.aiter_bytes(self.buffer_size):
                                file.write(chunk)
                                self.downloaded_size += len(chunk)
                                self.print_progress()
                except httpx.HTTPError as e:
                    print(f"\nDownload interrupted: {e}")
                    print("You can resume the download by running the script again.")
                    # Stop here: looping again would retry a failing request forever.
                    return
                except Exception as e:
                    print(f"\nAn error occurred: {e}")
                    return

        print(f"\nFinished downloading `{self.output_file}`....")

    def print_progress(self) -> None:
        """Redraw the single-line progress bar with percentage, size and speed."""
        if self.total_size <= 0:
            # Nothing meaningful to show, and the ratios below would divide by zero.
            return

        percent = (self.downloaded_size / self.total_size) * 100
        bar_length = 50
        filled_length = int(bar_length * self.downloaded_size // self.total_size)
        # "\u2588" is the full-block character used for the filled part of the bar.
        bar = "\u2588" * filled_length + "-" * (bar_length - filled_length)

        elapsed_time = time.time() - self.start_time
        mib = 1024 * 1024
        # Speed counts only the bytes fetched in this run, not those resumed from disk.
        speed = (
            (self.downloaded_size - self.start_size) / (elapsed_time * mib)
            if elapsed_time > 0
            else 0
        )
        print(
            f"\rProgress: |{bar}| {percent:.2f}% "
            f"({self.downloaded_size // mib}/{self.total_size // mib} MB) "
            f"Speed: {speed:.2f} MiB/s",
            end="",
            flush=True,
        )

    def has_disk_space(self, file_size: int) -> bool:
        """Return True if the output file's directory has more than ``file_size`` bytes free."""
        dir_path = os.path.dirname(os.path.abspath(self.output_file))
        free_space = shutil.disk_usage(dir_path).free
        return free_space > file_size