mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
local imports for faster cli
This commit is contained in:
parent
af4710c959
commit
67229f23a4
9 changed files with 44 additions and 47 deletions
|
@ -18,15 +18,8 @@ from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
|
|
||||||
from llama_toolchain.distribution.registry import (
|
|
||||||
available_distributions,
|
|
||||||
resolve_distribution,
|
|
||||||
)
|
|
||||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
|
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
|
||||||
|
|
||||||
from .utils import run_command
|
|
||||||
|
|
||||||
|
|
||||||
class DistributionConfigure(Subcommand):
|
class DistributionConfigure(Subcommand):
|
||||||
"""Llama cli for configuring llama toolchain configs"""
|
"""Llama cli for configuring llama toolchain configs"""
|
||||||
|
@ -43,6 +36,7 @@ class DistributionConfigure(Subcommand):
|
||||||
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
|
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
|
from llama_toolchain.distribution.registry import available_distributions
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--name",
|
"--name",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -52,6 +46,8 @@ class DistributionConfigure(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.distribution.registry import resolve_distribution
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
dist = resolve_distribution(args.name)
|
||||||
if dist is None:
|
if dist is None:
|
||||||
self.parser.error(f"Could not find distribution {args.name}")
|
self.parser.error(f"Could not find distribution {args.name}")
|
||||||
|
@ -66,7 +62,10 @@ class DistributionConfigure(Subcommand):
|
||||||
configure_llama_distribution(dist, conda_env)
|
configure_llama_distribution(dist, conda_env)
|
||||||
|
|
||||||
|
|
||||||
def configure_llama_distribution(dist: Distribution, conda_env: str):
|
def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
||||||
|
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
|
||||||
|
from .utils import run_command
|
||||||
|
|
||||||
python_exe = run_command(shlex.split("which python"))
|
python_exe = run_command(shlex.split("which python"))
|
||||||
# simple check
|
# simple check
|
||||||
if conda_env not in python_exe:
|
if conda_env not in python_exe:
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution
|
|
||||||
|
|
||||||
|
|
||||||
class DistributionCreate(Subcommand):
|
class DistributionCreate(Subcommand):
|
||||||
|
@ -35,6 +34,8 @@ class DistributionCreate(Subcommand):
|
||||||
# wants to pick and then ask for their configuration.
|
# wants to pick and then ask for their configuration.
|
||||||
|
|
||||||
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.distribution.registry import resolve_distribution
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
dist = resolve_distribution(args.name)
|
||||||
if dist is not None:
|
if dist is not None:
|
||||||
self.parser.error(f"Distribution with name {args.name} already exists")
|
self.parser.error(f"Distribution with name {args.name} already exists")
|
||||||
|
|
|
@ -11,17 +11,8 @@ import shlex
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
|
||||||
from llama_toolchain.distribution.registry import (
|
|
||||||
available_distributions,
|
|
||||||
resolve_distribution,
|
|
||||||
)
|
|
||||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
from .utils import run_command, run_with_pty
|
|
||||||
|
|
||||||
DISTRIBS = available_distributions()
|
|
||||||
|
|
||||||
|
|
||||||
class DistributionInstall(Subcommand):
|
class DistributionInstall(Subcommand):
|
||||||
"""Llama cli for configuring llama toolchain configs"""
|
"""Llama cli for configuring llama toolchain configs"""
|
||||||
|
@ -38,12 +29,13 @@ class DistributionInstall(Subcommand):
|
||||||
self.parser.set_defaults(func=self._run_distribution_install_cmd)
|
self.parser.set_defaults(func=self._run_distribution_install_cmd)
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
|
from llama_toolchain.distribution.registry import available_distributions
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--name",
|
"--name",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the distribution to install -- (try local-ollama)",
|
help="Name of the distribution to install -- (try local-ollama)",
|
||||||
required=True,
|
required=True,
|
||||||
choices=[d.name for d in DISTRIBS],
|
choices=[d.name for d in available_distributions()],
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--conda-env",
|
"--conda-env",
|
||||||
|
@ -53,6 +45,10 @@ class DistributionInstall(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||||
|
from llama_toolchain.distribution.registry import resolve_distribution
|
||||||
|
from .utils import run_command, run_with_pty
|
||||||
|
|
||||||
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
|
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
|
||||||
script = pkg_resources.resource_filename(
|
script = pkg_resources.resource_filename(
|
||||||
"llama_toolchain",
|
"llama_toolchain",
|
||||||
|
|
|
@ -7,10 +7,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.cli.table import print_table
|
|
||||||
|
|
||||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
|
||||||
from llama_toolchain.distribution.registry import available_distributions
|
|
||||||
|
|
||||||
|
|
||||||
class DistributionList(Subcommand):
|
class DistributionList(Subcommand):
|
||||||
|
@ -30,6 +26,10 @@ class DistributionList(Subcommand):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.cli.table import print_table
|
||||||
|
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||||
|
from llama_toolchain.distribution.registry import available_distributions
|
||||||
|
|
||||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||||
headers = [
|
headers = [
|
||||||
"Name",
|
"Name",
|
||||||
|
|
|
@ -11,12 +11,8 @@ from pathlib import Path
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.distribution.registry import resolve_distribution
|
|
||||||
from llama_toolchain.distribution.server import main as distribution_server_init
|
|
||||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
from .utils import run_command
|
|
||||||
|
|
||||||
|
|
||||||
class DistributionStart(Subcommand):
|
class DistributionStart(Subcommand):
|
||||||
|
|
||||||
|
@ -52,6 +48,10 @@ class DistributionStart(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.distribution.registry import resolve_distribution
|
||||||
|
from llama_toolchain.distribution.server import main as distribution_server_init
|
||||||
|
from .utils import run_command
|
||||||
|
|
||||||
dist = resolve_distribution(args.name)
|
dist = resolve_distribution(args.name)
|
||||||
if dist is None:
|
if dist is None:
|
||||||
self.parser.error(f"Distribution with name {args.name} not found")
|
self.parser.error(f"Distribution with name {args.name} not found")
|
||||||
|
|
|
@ -13,15 +13,6 @@ from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
|
||||||
|
|
||||||
from llama_models.datatypes import Model
|
|
||||||
from llama_models.sku_list import (
|
|
||||||
all_registered_models,
|
|
||||||
llama_meta_net_info,
|
|
||||||
resolve_model,
|
|
||||||
)
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
@ -46,6 +37,8 @@ class Download(Subcommand):
|
||||||
self.parser.set_defaults(func=self._run_download_cmd)
|
self.parser.set_defaults(func=self._run_download_cmd)
|
||||||
|
|
||||||
def _add_arguments(self):
|
def _add_arguments(self):
|
||||||
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
models = all_registered_models()
|
models = all_registered_models()
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--source",
|
"--source",
|
||||||
|
@ -81,7 +74,10 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _hf_download(self, model: Model, hf_token: str, ignore_patterns: str):
|
def _hf_download(self, model: "Model", hf_token: str, ignore_patterns: str):
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||||
|
|
||||||
repo_id = model.huggingface_repo
|
repo_id = model.huggingface_repo
|
||||||
if repo_id is None:
|
if repo_id is None:
|
||||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||||
|
@ -112,7 +108,9 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
|
|
||||||
print(f"Successfully downloaded model to {true_output_dir}")
|
print(f"Successfully downloaded model to {true_output_dir}")
|
||||||
|
|
||||||
def _meta_download(self, model: Model, meta_url: str):
|
def _meta_download(self, model: "Model", meta_url: str):
|
||||||
|
from llama_models.sku_list import llama_meta_net_info
|
||||||
|
|
||||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
@ -128,6 +126,8 @@ safetensors files to avoid downloading duplicate weights.
|
||||||
asyncio.run(downloader.download())
|
asyncio.run(downloader.download())
|
||||||
|
|
||||||
def _run_download_cmd(self, args: argparse.Namespace):
|
def _run_download_cmd(self, args: argparse.Namespace):
|
||||||
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
model = resolve_model(args.model_id)
|
model = resolve_model(args.model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
self.parser.error(f"Model {args.model_id} not found")
|
self.parser.error(f"Model {args.model_id} not found")
|
||||||
|
|
|
@ -8,8 +8,6 @@ import argparse
|
||||||
|
|
||||||
from .distribution import DistributionParser
|
from .distribution import DistributionParser
|
||||||
from .download import Download
|
from .download import Download
|
||||||
|
|
||||||
# from .inference import InferenceParser
|
|
||||||
from .model import ModelParser
|
from .model import ModelParser
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +28,6 @@ class LlamaCLIParser:
|
||||||
|
|
||||||
# Add sub-commands
|
# Add sub-commands
|
||||||
Download.create(subparsers)
|
Download.create(subparsers)
|
||||||
# InferenceParser.create(subparsers)
|
|
||||||
ModelParser.create(subparsers)
|
ModelParser.create(subparsers)
|
||||||
DistributionParser.create(subparsers)
|
DistributionParser.create(subparsers)
|
||||||
|
|
||||||
|
|
|
@ -7,14 +7,9 @@
|
||||||
import argparse
|
import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from llama_models.llama3_1.api.interface import (
|
|
||||||
list_jinja_templates,
|
|
||||||
render_jinja_template,
|
|
||||||
)
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from llama_toolchain.cli.subcommand import Subcommand
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
from llama_toolchain.cli.table import print_table
|
|
||||||
|
|
||||||
|
|
||||||
class ModelTemplate(Subcommand):
|
class ModelTemplate(Subcommand):
|
||||||
|
@ -53,6 +48,12 @@ class ModelTemplate(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_models.llama3_1.api.interface import (
|
||||||
|
list_jinja_templates,
|
||||||
|
render_jinja_template,
|
||||||
|
)
|
||||||
|
from llama_toolchain.cli.table import print_table
|
||||||
|
|
||||||
if args.name:
|
if args.name:
|
||||||
template, tokens_info = render_jinja_template(args.name)
|
template, tokens_info = render_jinja_template(args.name)
|
||||||
rendered = ""
|
rendered = ""
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from llama_toolchain.inference.adapters import available_inference_adapters
|
from llama_toolchain.inference.adapters import available_inference_adapters
|
||||||
|
@ -27,7 +28,7 @@ COMMON_DEPENDENCIES = [
|
||||||
"hydra-core",
|
"hydra-core",
|
||||||
"hydra-zen",
|
"hydra-zen",
|
||||||
"json-strong-typing",
|
"json-strong-typing",
|
||||||
"llama-models",
|
"git+ssh://git@github.com/meta-llama/llama-models.git",
|
||||||
"omegaconf",
|
"omegaconf",
|
||||||
"pandas",
|
"pandas",
|
||||||
"Pillow",
|
"Pillow",
|
||||||
|
@ -43,6 +44,7 @@ COMMON_DEPENDENCIES = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def available_distributions() -> List[Distribution]:
|
def available_distributions() -> List[Distribution]:
|
||||||
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
|
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
|
||||||
|
|
||||||
|
@ -66,6 +68,7 @@ def available_distributions() -> List[Distribution]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def resolve_distribution(name: str) -> Optional[Distribution]:
|
def resolve_distribution(name: str) -> Optional[Distribution]:
|
||||||
for dist in available_distributions():
|
for dist in available_distributions():
|
||||||
if dist.name == name:
|
if dist.name == name:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue