local imports for faster cli

Hardik Shah 2024-08-02 16:34:29 -07:00
parent af4710c959
commit 67229f23a4
9 changed files with 44 additions and 47 deletions
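
Every file in this commit applies the same pattern: imports needed by only one subcommand move from module scope into the function that uses them, so CLI startup no longer pays the import cost of every dependency on every invocation. A minimal sketch of the pattern, with illustrative names only (some_heavy_package and everything else below is a stand-in, not code from this commit):

import argparse


def run_download(args: argparse.Namespace) -> None:
    # Deferred import: the heavy dependency is loaded only when this
    # subcommand actually runs, not when the CLI process starts.
    from some_heavy_package import fetch  # hypothetical heavy module

    fetch(args.model_id)


def main() -> None:
    parser = argparse.ArgumentParser(prog="tool")
    subparsers = parser.add_subparsers(dest="command", required=True)
    download = subparsers.add_parser("download")
    download.add_argument("--model-id", required=True)
    download.set_defaults(func=run_download)

    args = parser.parse_args()
    args.func(args)  # the deferred import happens here, on this code path only


if __name__ == "__main__":
    main()

To find imports worth deferring, CPython's `python -X importtime -c "import llama_toolchain.cli"` prints a per-module breakdown of startup import cost.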

View file

@@ -18,15 +18,8 @@ from pydantic import BaseModel
 from termcolor import cprint

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
-from llama_toolchain.distribution.registry import (
-    available_distributions,
-    resolve_distribution,
-)
 from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
-
-from .utils import run_command


 class DistributionConfigure(Subcommand):
     """Llama cli for configuring llama toolchain configs"""
@@ -43,6 +36,7 @@ class DistributionConfigure(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_configure_cmd)

     def _add_arguments(self):
+        from llama_toolchain.distribution.registry import available_distributions
         self.parser.add_argument(
             "--name",
             type=str,
@@ -52,6 +46,8 @@ class DistributionConfigure(Subcommand):
         )

     def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.distribution.registry import resolve_distribution
+
         dist = resolve_distribution(args.name)
         if dist is None:
             self.parser.error(f"Could not find distribution {args.name}")
@@ -66,7 +62,10 @@ class DistributionConfigure(Subcommand):
         configure_llama_distribution(dist, conda_env)


-def configure_llama_distribution(dist: Distribution, conda_env: str):
+def configure_llama_distribution(dist: "Distribution", conda_env: str):
+    from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
+    from .utils import run_command
+
     python_exe = run_command(shlex.split("which python"))
     # simple check
     if conda_env not in python_exe:
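
Note the annotation change in the hunk above: with the module-level import of Distribution gone, the name no longer exists when the def line executes, so the signature quotes it (dist: "Distribution") and the annotation stays an unevaluated string. Where static type checkers still need to resolve the name, the usual companion idiom, not used in this commit, is a TYPE_CHECKING guard; a sketch:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so CLI startup does not pay for this import.
    from llama_toolchain.distribution.datatypes import Distribution


def configure_llama_distribution(dist: "Distribution", conda_env: str):
    ...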

View file

@@ -7,7 +7,6 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.registry import resolve_distribution


 class DistributionCreate(Subcommand):
@@ -35,6 +34,8 @@ class DistributionCreate(Subcommand):
         # wants to pick and then ask for their configuration.

     def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.distribution.registry import resolve_distribution
+
         dist = resolve_distribution(args.name)
         if dist is not None:
             self.parser.error(f"Distribution with name {args.name} already exists")

View file

@@ -11,17 +11,8 @@ import shlex
 import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.distribution import distribution_dependencies
-from llama_toolchain.distribution.registry import (
-    available_distributions,
-    resolve_distribution,
-)
 from llama_toolchain.utils import DISTRIBS_BASE_DIR
-from .utils import run_command, run_with_pty
-
-DISTRIBS = available_distributions()


 class DistributionInstall(Subcommand):
     """Llama cli for configuring llama toolchain configs"""
@@ -38,12 +29,13 @@ class DistributionInstall(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_install_cmd)

     def _add_arguments(self):
+        from llama_toolchain.distribution.registry import available_distributions
         self.parser.add_argument(
             "--name",
             type=str,
             help="Name of the distribution to install -- (try local-ollama)",
             required=True,
-            choices=[d.name for d in DISTRIBS],
+            choices=[d.name for d in available_distributions()],
         )
         self.parser.add_argument(
             "--conda-env",
@@ -53,6 +45,10 @@ class DistributionInstall(Subcommand):
         )

     def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.distribution.distribution import distribution_dependencies
+        from llama_toolchain.distribution.registry import resolve_distribution
+        from .utils import run_command, run_with_pty
+
         os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
         script = pkg_resources.resource_filename(
             "llama_toolchain",

View file

@@ -7,10 +7,6 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.cli.table import print_table
-from llama_toolchain.distribution.distribution import distribution_dependencies
-from llama_toolchain.distribution.registry import available_distributions


 class DistributionList(Subcommand):
@@ -30,6 +26,10 @@ class DistributionList(Subcommand):
         pass

     def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.cli.table import print_table
+        from llama_toolchain.distribution.distribution import distribution_dependencies
+        from llama_toolchain.distribution.registry import available_distributions
+
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
             "Name",

View file

@@ -11,12 +11,8 @@ from pathlib import Path
 import yaml

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.registry import resolve_distribution
-from llama_toolchain.distribution.server import main as distribution_server_init
 from llama_toolchain.utils import DISTRIBS_BASE_DIR
-
-from .utils import run_command


 class DistributionStart(Subcommand):
@@ -52,6 +48,10 @@ class DistributionStart(Subcommand):
         )

     def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.distribution.registry import resolve_distribution
+        from llama_toolchain.distribution.server import main as distribution_server_init
+        from .utils import run_command
+
         dist = resolve_distribution(args.name)
         if dist is None:
             self.parser.error(f"Distribution with name {args.name} not found")

View file

@@ -13,15 +13,6 @@ from pathlib import Path
 import httpx
-from huggingface_hub import snapshot_download
-from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
-from llama_models.datatypes import Model
-from llama_models.sku_list import (
-    all_registered_models,
-    llama_meta_net_info,
-    resolve_model,
-)
 from termcolor import cprint

 from llama_toolchain.cli.subcommand import Subcommand
@@ -46,6 +37,8 @@ class Download(Subcommand):
         self.parser.set_defaults(func=self._run_download_cmd)

     def _add_arguments(self):
+        from llama_models.sku_list import all_registered_models
+
         models = all_registered_models()
         self.parser.add_argument(
             "--source",
@@ -81,7 +74,10 @@ safetensors files to avoid downloading duplicate weights.
         """,
         )

-    def _hf_download(self, model: Model, hf_token: str, ignore_patterns: str):
+    def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
+
         repo_id = model.huggingface_repo
         if repo_id is None:
             raise ValueError(f"No repo id found for model {model.descriptor()}")
@@ -112,7 +108,9 @@ safetensors files to avoid downloading duplicate weights.
         print(f"Successfully downloaded model to {true_output_dir}")

-    def _meta_download(self, model: Model, meta_url: str):
+    def _meta_download(self, model: "Model", meta_url: str):
+        from llama_models.sku_list import llama_meta_net_info
+
         output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
         os.makedirs(output_dir, exist_ok=True)
@@ -128,6 +126,8 @@ safetensors files to avoid downloading duplicate weights.
         asyncio.run(downloader.download())

     def _run_download_cmd(self, args: argparse.Namespace):
+        from llama_models.sku_list import resolve_model
+
         model = resolve_model(args.model_id)
         if model is None:
             self.parser.error(f"Model {args.model_id} not found")

View file

@@ -8,8 +8,6 @@ import argparse
 from .distribution import DistributionParser
 from .download import Download
-
-# from .inference import InferenceParser
 from .model import ModelParser
@@ -30,7 +28,6 @@ class LlamaCLIParser:
         # Add sub-commands
         Download.create(subparsers)
-        # InferenceParser.create(subparsers)
         ModelParser.create(subparsers)
         DistributionParser.create(subparsers)

View file

@@ -7,14 +7,9 @@
 import argparse
 import textwrap

-from llama_models.llama3_1.api.interface import (
-    list_jinja_templates,
-    render_jinja_template,
-)
 from termcolor import colored

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.cli.table import print_table


 class ModelTemplate(Subcommand):
@@ -53,6 +48,12 @@ class ModelTemplate(Subcommand):
         )

     def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
+        from llama_models.llama3_1.api.interface import (
+            list_jinja_templates,
+            render_jinja_template,
+        )
+        from llama_toolchain.cli.table import print_table
+
         if args.name:
             template, tokens_info = render_jinja_template(args.name)
             rendered = ""

View file

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from functools import lru_cache
 from typing import List, Optional

 from llama_toolchain.inference.adapters import available_inference_adapters
@@ -27,7 +28,7 @@ COMMON_DEPENDENCIES = [
     "hydra-core",
     "hydra-zen",
     "json-strong-typing",
-    "llama-models",
+    "git+ssh://git@github.com/meta-llama/llama-models.git",
     "omegaconf",
     "pandas",
     "Pillow",
@@ -43,6 +44,7 @@ COMMON_DEPENDENCIES = [
 ]


+@lru_cache()
 def available_distributions() -> List[Distribution]:
     inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
@@ -66,6 +68,7 @@ def available_distributions() -> List[Distribution]:
     ]


+@lru_cache()
 def resolve_distribution(name: str) -> Optional[Distribution]:
     for dist in available_distributions():
         if dist.name == name:
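
Memoization is what keeps the registry change cheap: available_distributions() used to run once at import time (see the deleted DISTRIBS = available_distributions() in the install command) but is now called from several subcommand bodies, so @lru_cache() makes the list build happen once per process, and resolve_distribution(name) is cached per distinct name. A small self-contained sketch of the behavior, using stand-in names rather than the real registry:

from functools import lru_cache
from typing import List, Optional


@lru_cache()
def available_widgets() -> List[str]:
    # Stand-in for the expensive registry build; the body runs once,
    # after which every call returns the same cached list object.
    print("building registry...")
    return ["local", "local-ollama"]


@lru_cache()
def resolve_widget(name: str) -> Optional[str]:
    # Cached separately for each distinct name argument.
    for widget in available_widgets():
        if widget == name:
            return widget
    return None


resolve_widget("local")    # prints "building registry..." exactly once
resolve_widget("local")    # both caches hit; nothing is rebuilt
available_widgets()        # returns the cached list

One caveat of this design: lru_cache hands the same mutable list to every caller, so callers must treat it as read-only.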