local imports for faster cli

This commit is contained in:
Hardik Shah 2024-08-02 16:34:29 -07:00
parent af4710c959
commit 67229f23a4
9 changed files with 44 additions and 47 deletions

View file

@ -11,17 +11,8 @@ import shlex
import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import (
available_distributions,
resolve_distribution,
)
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command, run_with_pty
DISTRIBS = available_distributions()
class DistributionInstall(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
@ -38,12 +29,13 @@ class DistributionInstall(Subcommand):
self.parser.set_defaults(func=self._run_distribution_install_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to install -- (try local-ollama)",
required=True,
choices=[d.name for d in DISTRIBS],
choices=[d.name for d in available_distributions()],
)
self.parser.add_argument(
"--conda-env",
@ -53,6 +45,10 @@ class DistributionInstall(Subcommand):
)
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution
from .utils import run_command, run_with_pty
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",