diff --git a/llama_toolchain/cli/distribution/__init__.py b/llama_toolchain/cli/distribution/__init__.py new file mode 100644 index 000000000..81278f253 --- /dev/null +++ b/llama_toolchain/cli/distribution/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .distribution import DistributionParser # noqa diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py new file mode 100644 index 000000000..98d3d47dd --- /dev/null +++ b/llama_toolchain/cli/distribution/create.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from llama_toolchain.cli.subcommand import Subcommand + + +class DistributionCreate(Subcommand): + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "create", + prog="llama distribution create", + description="create a Llama stack distribution", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_distribution_create_cmd) + + def _add_arguments(self): + pass + + def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None: + raise NotImplementedError() diff --git a/llama_toolchain/cli/distribution/distribution.py b/llama_toolchain/cli/distribution/distribution.py new file mode 100644 index 000000000..f8482de87 --- /dev/null +++ b/llama_toolchain/cli/distribution/distribution.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from llama_toolchain.cli.subcommand import Subcommand +from .create import DistributionCreate +from .install import DistributionInstall +from .list import DistributionList + + +class DistributionParser(Subcommand): + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "distribution", + prog="llama distribution", + description="Operate on llama stack distributions", + ) + + subparsers = self.parser.add_subparsers(title="distribution_subcommands") + + # Add sub-commands + DistributionList.create(subparsers) + DistributionInstall.create(subparsers) + DistributionCreate.create(subparsers) diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py new file mode 100644 index 000000000..f0760b07b --- /dev/null +++ b/llama_toolchain/cli/distribution/install.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse +import os + +from pathlib import Path + +import pkg_resources + +from llama_toolchain.cli.subcommand import Subcommand +from llama_toolchain.distribution.registry import all_registered_distributions +from llama_toolchain.utils import DEFAULT_DUMP_DIR + + +CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs") + + +class DistributionInstall(Subcommand): + """Llama cli for configuring llama toolchain configs""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "install", + prog="llama distribution install", + description="Install a llama stack distribution", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_distribution_install_cmd) + + def _add_arguments(self): + distribs = all_registered_distributions() + self.parser.add_argument( + "--name", + type=str, + help="Name of the distribution to install", + default="local-source", + choices=[d.name for d in distribs], + ) + + def read_user_inputs(self): + checkpoint_dir = input( + "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): " + ) + model_parallel_size = input( + "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): " + ) + assert model_parallel_size.isdigit() and int(model_parallel_size) in { + 1, + 8, + }, "model parallel size must be 1 or 8" + + return checkpoint_dir, model_parallel_size + + def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path): + default_conf_path = pkg_resources.resource_filename( + "llama_toolchain", "data/default_distribution_config.yaml" + ) + with open(default_conf_path, "r") as f: + yaml_content = f.read() + + yaml_content = yaml_content.format( + checkpoint_dir=checkpoint_dir, + model_parallel_size=model_parallel_size, + ) + + with open(yaml_output_path, "w") as yaml_file: + yaml_file.write(yaml_content.strip()) + + print(f"YAML 
configuration has been written to {yaml_output_path}") + + def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None: + checkpoint_dir, model_parallel_size = self.read_user_inputs() + checkpoint_dir = os.path.expanduser(checkpoint_dir) + + assert ( + Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir() + ), f"{checkpoint_dir} does not exist or is not a directory" + + os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) + yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml" + + self.write_output_yaml( + checkpoint_dir, + model_parallel_size, + yaml_output_path, + ) diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py new file mode 100644 index 000000000..4cf26980b --- /dev/null +++ b/llama_toolchain/cli/distribution/list.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse + +from llama_toolchain.cli.subcommand import Subcommand +from llama_toolchain.cli.table import print_table + +from llama_toolchain.distribution.registry import all_registered_distributions + + +class DistributionList(Subcommand): + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama distribution list", + description="Show available llama stack distributions", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_distribution_list_cmd) + + def _add_arguments(self): + pass + + def _run_distribution_list_cmd(self, args: argparse.Namespace) -> None: + # eventually, this should query a registry at llama.meta.com/llamastack/distributions + headers = [ + "Name", + "Description", + "Dependencies", + ] + + rows = [] + for dist in all_registered_distributions(): + rows.append( + [ + dist.name, + dist.description, + ", ".join(dist.pip_packages), + ] + ) + print_table( + rows, + headers, + separate_rows=True, + ) diff --git a/llama_toolchain/cli/inference/__init__.py b/llama_toolchain/cli/inference/__init__.py index 756f351d8..74f5fc120 100644 --- a/llama_toolchain/cli/inference/__init__.py +++ b/llama_toolchain/cli/inference/__init__.py @@ -3,3 +3,5 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+ +from .inference import InferenceParser # noqa diff --git a/llama_toolchain/cli/llama.py b/llama_toolchain/cli/llama.py index 5bdb7ca59..f77855277 100644 --- a/llama_toolchain/cli/llama.py +++ b/llama_toolchain/cli/llama.py @@ -6,9 +6,10 @@ import argparse -from llama_toolchain.cli.download import Download -from llama_toolchain.cli.inference.inference import InferenceParser -from llama_toolchain.cli.model.model import ModelParser +from .distribution import DistributionParser +from .download import Download +from .inference import InferenceParser +from .model import ModelParser class LlamaCLIParser: @@ -30,6 +31,7 @@ class LlamaCLIParser: Download.create(subparsers) InferenceParser.create(subparsers) ModelParser.create(subparsers) + DistributionParser.create(subparsers) # Import sub-commands from agentic_system if they exist try: diff --git a/llama_toolchain/cli/model/__init__.py b/llama_toolchain/cli/model/__init__.py index 756f351d8..db70364a9 100644 --- a/llama_toolchain/cli/model/__init__.py +++ b/llama_toolchain/cli/model/__init__.py @@ -3,3 +3,5 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +from .model import ModelParser # noqa diff --git a/llama_toolchain/distribution/__init__.py b/llama_toolchain/distribution/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/distribution/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py new file mode 100644 index 000000000..753ce3b6d --- /dev/null +++ b/llama_toolchain/distribution/datatypes.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from pydantic import BaseModel + + +class LlamaStackDistribution(BaseModel): + name: str + description: str + + # you must install the packages to get the functionality needed. + # later, we may have a docker image be the main artifact of + # a distribution. + pip_packages: List[str] diff --git a/llama_toolchain/distribution/local-ollama/install.sh b/llama_toolchain/distribution/local-ollama/install.sh new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/distribution/local-ollama/install.sh @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/distribution/local/install.sh b/llama_toolchain/distribution/local/install.sh new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_toolchain/distribution/local/install.sh @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py new file mode 100644 index 000000000..6abce380f --- /dev/null +++ b/llama_toolchain/distribution/registry.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import List + +from .datatypes import LlamaStackDistribution + + +def all_registered_distributions() -> List[LlamaStackDistribution]: + return [ + LlamaStackDistribution( + name="local-source", + description="Use code within `llama_toolchain` itself to run model inference and everything on top", + pip_packages=[], + ), + LlamaStackDistribution( + name="local-ollama", + description="Like local-source, but use ollama for running LLM inference", + pip_packages=[], + ), + ]