From dac2b5a1ed3f631f61757c8d817f4e6bdcd38a7b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 1 Aug 2024 16:40:43 -0700
Subject: [PATCH] More progress towards `llama distribution install`

---
 llama_toolchain/cli/distribution/configure.py |  92 +++++++++++++++++++
 .../cli/distribution/distribution.py          |   2 +
 llama_toolchain/cli/distribution/install.py   |  90 ++++++++----------
 llama_toolchain/cli/download.py               |   4 +-
 llama_toolchain/cli/inference/configure.py    |   4 +-
 llama_toolchain/distribution/datatypes.py     |  80 +++++++++++++++-
 .../distribution/install_distribution.sh      |  64 +++++++++++++
 .../distribution/local-ollama/install.sh      |   5 -
 llama_toolchain/distribution/local/install.sh |   5 -
 llama_toolchain/distribution/registry.py      |  23 +++--
 llama_toolchain/utils.py                      |   4 +-
 11 files changed, 298 insertions(+), 75 deletions(-)
 create mode 100644 llama_toolchain/cli/distribution/configure.py
 create mode 100755 llama_toolchain/distribution/install_distribution.sh
 delete mode 100644 llama_toolchain/distribution/local-ollama/install.sh
 delete mode 100644 llama_toolchain/distribution/local/install.sh

diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py
new file mode 100644
index 000000000..1e0712b4a
--- /dev/null
+++ b/llama_toolchain/cli/distribution/configure.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import os
+
+from pathlib import Path
+
+import pkg_resources
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.distribution.registry import all_registered_distributions
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
+
+
+CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
+
+
+class DistributionConfigure(Subcommand):
+    """Llama CLI for configuring llama toolchain configs"""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "configure",
+            prog="llama distribution configure",
+            description="configure a llama stack distribution",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_distribution_configure_cmd)
+
+    def _add_arguments(self):
+        distribs = all_registered_distributions()
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the distribution to configure",
+            default="local-source",
+            choices=[d.name for d in distribs],
+        )
+
+    def read_user_inputs(self):
+        checkpoint_dir = input(
+            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
+        )
+        model_parallel_size = input(
+            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
+        )
+        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
+            1,
+            8,
+        }, "model parallel size must be 1 or 8"
+
+        return checkpoint_dir, model_parallel_size
+
+    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
+        default_conf_path = pkg_resources.resource_filename(
+            "llama_toolchain", "data/default_distribution_config.yaml"
+        )
+        with open(default_conf_path, "r") as f:
+            yaml_content = f.read()
+
+        yaml_content = yaml_content.format(
+            checkpoint_dir=checkpoint_dir,
+            model_parallel_size=model_parallel_size,
+        )
+
+        with open(yaml_output_path, "w") as yaml_file:
+            yaml_file.write(yaml_content.strip())
+
print(f"YAML configuration has been written to {yaml_output_path}") + + def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None: + checkpoint_dir, model_parallel_size = self.read_user_inputs() + checkpoint_dir = os.path.expanduser(checkpoint_dir) + + assert ( + Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir() + ), f"{checkpoint_dir} does not exist or it not a directory" + + os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) + yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml" + + self.write_output_yaml( + checkpoint_dir, + model_parallel_size, + yaml_output_path, + ) diff --git a/llama_toolchain/cli/distribution/distribution.py b/llama_toolchain/cli/distribution/distribution.py index f8482de87..02a0b8caf 100644 --- a/llama_toolchain/cli/distribution/distribution.py +++ b/llama_toolchain/cli/distribution/distribution.py @@ -7,6 +7,7 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand +from .configure import DistributionConfigure from .create import DistributionCreate from .install import DistributionInstall from .list import DistributionList @@ -28,3 +29,4 @@ class DistributionParser(Subcommand): DistributionList.create(subparsers) DistributionInstall.create(subparsers) DistributionCreate.create(subparsers) + DistributionConfigure.create(subparsers) diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index f0760b07b..d8bbcb599 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -6,6 +6,8 @@ import argparse import os +import shlex +import subprocess from pathlib import Path @@ -13,10 +15,12 @@ import pkg_resources from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.distribution.registry import all_registered_distributions -from llama_toolchain.utils import DEFAULT_DUMP_DIR +from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR -CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs") +DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions" + +DISTRIBS = all_registered_distributions() class DistributionInstall(Subcommand): @@ -34,59 +38,45 @@ class DistributionInstall(Subcommand): self.parser.set_defaults(func=self._run_distribution_install_cmd) def _add_arguments(self): - distribs = all_registered_distributions() self.parser.add_argument( "--name", type=str, - help="Mame of the distribution to install", - default="local-source", - choices=[d.name for d in distribs], + help="Name of the distribution to install -- (try local-ollama)", + required=True, + choices=[d.name for d in DISTRIBS], ) - - def read_user_inputs(self): - checkpoint_dir = input( - "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): " + self.parser.add_argument( + "--conda-env", + type=str, + help="Specify the name of the conda environment you would like to create or update", + required=True, ) - model_parallel_size = input( - "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): " - ) - assert model_parallel_size.isdigit() and int(model_parallel_size) in { - 1, - 8, - }, "model parallel size must be 1 or 8" - - return checkpoint_dir, model_parallel_size - - def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path): - default_conf_path = pkg_resources.resource_filename( - "llama_toolchain", "data/default_distribution_config.yaml" - ) - with open(default_conf_path, "r") as f: - yaml_content = f.read() - - yaml_content = yaml_content.format( 
-            checkpoint_dir=checkpoint_dir,
-            model_parallel_size=model_parallel_size,
-        )
-
-        with open(yaml_output_path, "w") as yaml_file:
-            yaml_file.write(yaml_content.strip())
-
-        print(f"YAML configuration has been written to {yaml_output_path}")
 
     def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
-        checkpoint_dir, model_parallel_size = self.read_user_inputs()
-        checkpoint_dir = os.path.expanduser(checkpoint_dir)
-
-        assert (
-            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
-        ), f"{checkpoint_dir} does not exist or it not a directory"
-
-        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
-        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
-
-        self.write_output_yaml(
-            checkpoint_dir,
-            model_parallel_size,
-            yaml_output_path,
+        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
+        script = pkg_resources.resource_filename(
+            "llama_toolchain",
+            "distribution/install_distribution.sh",
         )
+
+        dist = None
+        for d in DISTRIBS:
+            if d.name == args.name:
+                dist = d
+                break
+
+        if dist is None:
+            self.parser.error(f"Could not find distribution {args.name}")
+            return
+
+        os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
+        run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
+        with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
+            f.write(f"{args.conda_env}\n")
+
+
+def run_shell_script(script_path, *args):
+    command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
+    command_list = shlex.split(command_string)
+    print(f"Running command: {command_list}")
+    subprocess.run(command_list, check=True, text=True)
diff --git a/llama_toolchain/cli/download.py b/llama_toolchain/cli/download.py
index d100cee61..b71738bb7 100644
--- a/llama_toolchain/cli/download.py
+++ b/llama_toolchain/cli/download.py
@@ -25,10 +25,10 @@ from llama_models.sku_list import (
 from termcolor import cprint
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
 
 
-DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
+DEFAULT_CHECKPOINT_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "checkpoints")
 
 
 class Download(Subcommand):
diff --git a/llama_toolchain/cli/inference/configure.py b/llama_toolchain/cli/inference/configure.py
index fb7a309a9..1a511ea62 100644
--- a/llama_toolchain/cli/inference/configure.py
+++ b/llama_toolchain/cli/inference/configure.py
@@ -13,10 +13,10 @@ from pathlib import Path
 import pkg_resources
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR
 
 
-CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
+CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
 
 
 class InferenceConfigure(Subcommand):
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 753ce3b6d..5cef08dcc 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -4,12 +4,84 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List
+from enum import Enum
+from typing import Any, Dict, List, Literal, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from strong_typing.schema import json_schema_type
+from typing_extensions import Annotated
 
 
-class LlamaStackDistribution(BaseModel):
+@json_schema_type
+class AdapterType(Enum):
+    passthrough_api = "passthrough_api"
+    python_impl = "python_impl"
+    not_implemented = "not_implemented"
+
+
+@json_schema_type
+class PassthroughApiAdapterConfig(BaseModel):
+    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
+    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    headers: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Headers (e.g., authorization) to send with the request",
+    )
+
+
+@json_schema_type
+class PythonImplAdapterConfig(BaseModel):
+    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+    module: str = Field(..., description="The name of the module to import")
+    entrypoint: str = Field(
+        ...,
+        description="The name of the entrypoint function which creates the implementation for the API",
+    )
+    kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="kwargs to pass to the entrypoint"
+    )
+
+
+@json_schema_type
+class NotImplementedAdapterConfig(BaseModel):
+    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
+
+
+# should we define very granular typed classes for each of the PythonImplAdapters we will have?
+# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
+AdapterConfig = Annotated[
+    Union[
+        PassthroughApiAdapterConfig,
+        NotImplementedAdapterConfig,
+        PythonImplAdapterConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class DistributionConfig(BaseModel):
+    inference: AdapterConfig
+    safety: AdapterConfig
+
+    # configs for each API that the stack provides, e.g.
+    # agentic_system: AdapterConfig
+    # post_training: AdapterConfig
+
+
+class DistributionConfigDefaults(BaseModel):
+    inference: Dict[str, Any] = Field(
+        default_factory=dict, description="Default kwargs for the inference adapter"
+    )
+    safety: Dict[str, Any] = Field(
+        default_factory=dict, description="Default kwargs for the safety adapter"
+    )
+
+
+class Distribution(BaseModel):
     name: str
     description: str
 
@@ -17,3 +89,5 @@ class LlamaStackDistribution(BaseModel):
     # later, we may have a docker image be the main artifact of
     # a distribution.
     pip_packages: List[str]
+
+    config_defaults: DistributionConfigDefaults
diff --git a/llama_toolchain/distribution/install_distribution.sh b/llama_toolchain/distribution/install_distribution.sh
new file mode 100755
index 000000000..0707f4d6b
--- /dev/null
+++ b/llama_toolchain/distribution/install_distribution.sh
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+set -euo pipefail
+
+error_handler() {
+    echo "Error occurred in script at line: ${1}" >&2
+    exit 1
+}
+
+# Set up the error trap
+trap 'error_handler ${LINENO}' ERR
+
+ensure_conda_env_python310() {
+    local env_name="$1"
+    local pip_dependencies="$2"
+    local python_version="3.10"
+
+    # Check if conda command is available
+    if ! command -v conda &>/dev/null; then
+        echo "Error: conda command not found. Is Conda installed and in your PATH?" >&2
+        exit 1
+    fi
+
+    # Check if the environment exists
+    if conda env list | grep -q "^${env_name} "; then
+        echo "Conda environment '${env_name}' exists. Checking Python version..."
+
+        # Check Python version in the environment
+        current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
+
+        if [ "$current_version" = "$python_version" ]; then
+            echo "Environment '${env_name}' already has Python ${python_version}. No action needed."
+        else
+            echo "Updating environment '${env_name}' to Python ${python_version}..."
+            conda install -n "${env_name}" python="${python_version}" -y
+        fi
+    else
+        echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
+        conda create -n "${env_name}" python="${python_version}" -y
+    fi
+
+    # Install pip dependencies
+    if [ -n "$pip_dependencies" ]; then
+        echo "Installing pip dependencies: $pip_dependencies"
+        conda run -n "${env_name}" pip install $pip_dependencies
+    fi
+}
+
+if [ "$#" -ne 2 ]; then
+    echo "Usage: $0 <env_name> <pip_dependencies>" >&2
+    echo "Example: $0 my_env 'numpy pandas scipy'" >&2
+    exit 1
+fi
+
+env_name="$1"
+pip_dependencies="$2"
+
+ensure_conda_env_python310 "$env_name" "$pip_dependencies"
diff --git a/llama_toolchain/distribution/local-ollama/install.sh b/llama_toolchain/distribution/local-ollama/install.sh
deleted file mode 100644
index 756f351d8..000000000
--- a/llama_toolchain/distribution/local-ollama/install.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
diff --git a/llama_toolchain/distribution/local/install.sh b/llama_toolchain/distribution/local/install.sh
deleted file mode 100644
index 756f351d8..000000000
--- a/llama_toolchain/distribution/local/install.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
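Taken together, the new `install.py` and `install_distribution.sh` implement the install flow: resolve the requested distribution, locate the bundled installer script, and hand it the conda env name plus the distribution's pip packages. A minimal sketch of that flow, not part of the patch; the env name is illustrative, and "ollama" is the one pip package `local-ollama` declares in `registry.py` below:

import pkg_resources

from llama_toolchain.cli.distribution.install import run_shell_script

# Locate the installer script shipped inside the package.
script = pkg_resources.resource_filename(
    "llama_toolchain", "distribution/install_distribution.sh"
)

# Creates (or updates) a Python 3.10 conda env named "ollama-env",
# then pip-installs the distribution's dependencies into it.
run_shell_script(script, "ollama-env", "ollama")
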
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index 6abce380f..a1f9a7a55 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -6,19 +6,30 @@
 
 from typing import List
 
-from .datatypes import LlamaStackDistribution
+from .datatypes import Distribution, DistributionConfigDefaults
 
 
-def all_registered_distributions() -> List[LlamaStackDistribution]:
+def all_registered_distributions() -> List[Distribution]:
     return [
-        LlamaStackDistribution(
+        Distribution(
             name="local-source",
-            description="Use code within `llama_toolchain` itself to run model inference and everything on top",
+            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             pip_packages=[],
+            config_defaults=DistributionConfigDefaults(
+                inference={
+                    "max_seq_len": 4096,
+                    "max_batch_size": 1,
+                },
+                safety={},
+            ),
         ),
-        LlamaStackDistribution(
+        Distribution(
             name="local-ollama",
             description="Like local-source, but use ollama for running LLM inference",
-            pip_packages=[],
+            pip_packages=["ollama"],
+            config_defaults=DistributionConfigDefaults(
+                inference={},
+                safety={},
+            ),
         ),
     ]
diff --git a/llama_toolchain/utils.py b/llama_toolchain/utils.py
index 89f5c070a..2bf3be4e3 100644
--- a/llama_toolchain/utils.py
+++ b/llama_toolchain/utils.py
@@ -14,7 +14,7 @@
 from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf
 
-DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
+LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
 
 
 def get_root_directory():
@@ -26,7 +26,7 @@ def get_root_directory():
 
 
 def get_default_config_dir():
-    return os.path.join(DEFAULT_DUMP_DIR, "configs")
+    return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
 
 
 def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
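
For reference, the `AdapterConfig` union added in `datatypes.py` is a pydantic discriminated union: the `type` field of a plain dict decides which adapter class the dict parses into. A minimal sketch, assuming pydantic's discriminated-union support (v1.9+); the base URL is illustrative:

from llama_toolchain.distribution.datatypes import (
    DistributionConfig,
    NotImplementedAdapterConfig,
    PassthroughApiAdapterConfig,
)

# pydantic dispatches on the "type" discriminator to choose the model class.
config = DistributionConfig(
    inference={
        "type": "passthrough_api",
        "base_url": "http://localhost:5000",  # illustrative endpoint
    },
    safety={"type": "not_implemented"},
)

assert isinstance(config.inference, PassthroughApiAdapterConfig)
assert isinstance(config.safety, NotImplementedAdapterConfig)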