diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index 88666f67e..20dc6955c 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -18,15 +18,8 @@ from pydantic import BaseModel from termcolor import cprint from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter -from llama_toolchain.distribution.registry import ( - available_distributions, - resolve_distribution, -) from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder -from .utils import run_command - class DistributionConfigure(Subcommand): """Llama cli for configuring llama toolchain configs""" @@ -43,6 +36,7 @@ class DistributionConfigure(Subcommand): self.parser.set_defaults(func=self._run_distribution_configure_cmd) def _add_arguments(self): + from llama_toolchain.distribution.registry import available_distributions self.parser.add_argument( "--name", type=str, @@ -52,6 +46,8 @@ class DistributionConfigure(Subcommand): ) def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.distribution.registry import resolve_distribution + dist = resolve_distribution(args.name) if dist is None: self.parser.error(f"Could not find distribution {args.name}") @@ -66,7 +62,10 @@ class DistributionConfigure(Subcommand): configure_llama_distribution(dist, conda_env) -def configure_llama_distribution(dist: Distribution, conda_env: str): +def configure_llama_distribution(dist: "Distribution", conda_env: str): + from llama_toolchain.distribution.datatypes import PassthroughApiAdapter + from .utils import run_command + python_exe = run_command(shlex.split("which python")) # simple check if conda_env not in python_exe: diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py index d169e46cc..e1cff1244 100644 --- a/llama_toolchain/cli/distribution/create.py +++ b/llama_toolchain/cli/distribution/create.py @@ -7,7 +7,6 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.distribution.registry import resolve_distribution class DistributionCreate(Subcommand): @@ -35,6 +34,8 @@ class DistributionCreate(Subcommand): # wants to pick and then ask for their configuration. def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.distribution.registry import resolve_distribution + dist = resolve_distribution(args.name) if dist is not None: self.parser.error(f"Distribution with name {args.name} already exists") diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index d45456f75..367906e32 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -11,17 +11,8 @@ import shlex import pkg_resources from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.distribution.distribution import distribution_dependencies -from llama_toolchain.distribution.registry import ( - available_distributions, - resolve_distribution, -) from llama_toolchain.utils import DISTRIBS_BASE_DIR -from .utils import run_command, run_with_pty - -DISTRIBS = available_distributions() - class DistributionInstall(Subcommand): """Llama cli for configuring llama toolchain configs""" @@ -38,12 +29,13 @@ class DistributionInstall(Subcommand): self.parser.set_defaults(func=self._run_distribution_install_cmd) def _add_arguments(self): + from llama_toolchain.distribution.registry import available_distributions self.parser.add_argument( "--name", type=str, help="Name of the distribution to install -- (try local-ollama)", required=True, - choices=[d.name for d in DISTRIBS], + choices=[d.name for d in available_distributions()], ) self.parser.add_argument( "--conda-env", @@ -53,6 +45,10 @@ class DistributionInstall(Subcommand): ) def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.distribution.distribution import distribution_dependencies + from llama_toolchain.distribution.registry import resolve_distribution + from .utils import run_command, run_with_pty + os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True) script = pkg_resources.resource_filename( "llama_toolchain", diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py index d20980432..c13c11e66 100644 --- a/llama_toolchain/cli/distribution/list.py +++ b/llama_toolchain/cli/distribution/list.py @@ -7,10 +7,6 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.cli.table import print_table - -from llama_toolchain.distribution.distribution import distribution_dependencies -from llama_toolchain.distribution.registry import available_distributions class DistributionList(Subcommand): @@ -30,6 +26,10 @@ class DistributionList(Subcommand): pass def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.cli.table import print_table + from llama_toolchain.distribution.distribution import distribution_dependencies + from llama_toolchain.distribution.registry import available_distributions + # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ "Name", diff --git a/llama_toolchain/cli/distribution/start.py b/llama_toolchain/cli/distribution/start.py index b567726db..04caeca51 100644 --- a/llama_toolchain/cli/distribution/start.py +++ b/llama_toolchain/cli/distribution/start.py @@ -11,12 +11,8 @@ from pathlib import Path import yaml from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.distribution.registry import resolve_distribution -from llama_toolchain.distribution.server import main as distribution_server_init from llama_toolchain.utils import DISTRIBS_BASE_DIR -from .utils import run_command - class DistributionStart(Subcommand): @@ -52,6 +48,10 @@ class DistributionStart(Subcommand): ) def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None: + from llama_toolchain.distribution.registry import resolve_distribution + from llama_toolchain.distribution.server import main as distribution_server_init + from .utils import run_command + dist = resolve_distribution(args.name) if dist is None: self.parser.error(f"Distribution with name {args.name} not found") diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py index b71738bb7..892af927a 100644 --- a/llama_toolchain/cli/download.py +++ b/llama_toolchain/cli/download.py @@ -13,15 +13,6 @@ from pathlib import Path import httpx -from huggingface_hub import snapshot_download -from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError - -from llama_models.datatypes import Model -from llama_models.sku_list import ( - all_registered_models, - llama_meta_net_info, - resolve_model, -) from termcolor import cprint from llama_toolchain.cli.subcommand import Subcommand @@ -46,6 +37,8 @@ class Download(Subcommand): self.parser.set_defaults(func=self._run_download_cmd) def _add_arguments(self): + from llama_models.sku_list import all_registered_models + models = all_registered_models() self.parser.add_argument( "--source", @@ -81,7 +74,10 @@ safetensors files to avoid downloading duplicate weights. """, ) - def _hf_download(self, model: Model, hf_token: str, ignore_patterns: str): + def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str): + from huggingface_hub import snapshot_download + from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + repo_id = model.huggingface_repo if repo_id is None: raise ValueError(f"No repo id found for model {model.descriptor()}") @@ -112,7 +108,9 @@ safetensors files to avoid downloading duplicate weights. print(f"Successfully downloaded model to {true_output_dir}") - def _meta_download(self, model: Model, meta_url: str): + def _meta_download(self, model: "Model", meta_url: str): + from llama_models.sku_list import llama_meta_net_info + output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor() os.makedirs(output_dir, exist_ok=True) @@ -128,6 +126,8 @@ safetensors files to avoid downloading duplicate weights. asyncio.run(downloader.download()) def _run_download_cmd(self, args: argparse.Namespace): + from llama_models.sku_list import resolve_model + model = resolve_model(args.model_id) if model is None: self.parser.error(f"Model {args.model_id} not found") diff --git a/llama_toolchain/cli/llama.py b/llama_toolchain/cli/llama.py index 4764cf32e..5ff11ae84 100644 --- a/llama_toolchain/cli/llama.py +++ b/llama_toolchain/cli/llama.py @@ -8,8 +8,6 @@ import argparse from .distribution import DistributionParser from .download import Download - -# from .inference import InferenceParser from .model import ModelParser @@ -30,7 +28,6 @@ class LlamaCLIParser: # Add sub-commands Download.create(subparsers) - # InferenceParser.create(subparsers) ModelParser.create(subparsers) DistributionParser.create(subparsers) diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py index 1fdeab3e7..c0ba60882 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_toolchain/cli/model/template.py @@ -7,14 +7,9 @@ import argparse import textwrap -from llama_models.llama3_1.api.interface import ( - list_jinja_templates, - render_jinja_template, -) from termcolor import colored from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.cli.table import print_table class ModelTemplate(Subcommand): @@ -53,6 +48,12 @@ class ModelTemplate(Subcommand): ) def _run_model_template_cmd(self, args: argparse.Namespace) -> None: + from llama_models.llama3_1.api.interface import ( + list_jinja_templates, + render_jinja_template, + ) + from llama_toolchain.cli.table import print_table + if args.name: template, tokens_info = render_jinja_template(args.name) rendered = "" diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 48124c7d1..ef27a2bbc 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from functools import lru_cache from typing import List, Optional from llama_toolchain.inference.adapters import available_inference_adapters @@ -27,7 +28,7 @@ COMMON_DEPENDENCIES = [ "hydra-core", "hydra-zen", "json-strong-typing", - "llama-models", + "git+ssh://git@github.com/meta-llama/llama-models.git", "omegaconf", "pandas", "Pillow", @@ -43,6 +44,7 @@ COMMON_DEPENDENCIES = [ ] +@lru_cache() def available_distributions() -> List[Distribution]: inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()} @@ -66,6 +68,7 @@ def available_distributions() -> List[Distribution]: ] +@lru_cache() def resolve_distribution(name: str) -> Optional[Distribution]: for dist in available_distributions(): if dist.name == name: