More progress towards llama distribution install

Ashwin Bharambe 2024-08-01 16:40:43 -07:00
parent 5a583cf16e
commit dac2b5a1ed
11 changed files with 298 additions and 75 deletions

View file

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os
from pathlib import Path

import pkg_resources

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import all_registered_distributions
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR

CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")


class DistributionConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama distribution configure",
            description="configure a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_configure_cmd)

    def _add_arguments(self):
        distribs = all_registered_distributions()
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to configure",
            default="local-source",
            choices=[d.name for d in distribs],
        )

    def read_user_inputs(self):
        checkpoint_dir = input(
            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
        )
        model_parallel_size = input(
            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
        )
        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
            1,
            8,
        }, "model parallel size must be 1 or 8"

        return checkpoint_dir, model_parallel_size

    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
        default_conf_path = pkg_resources.resource_filename(
            "llama_toolchain", "data/default_distribution_config.yaml"
        )
        with open(default_conf_path, "r") as f:
            yaml_content = f.read()

        yaml_content = yaml_content.format(
            checkpoint_dir=checkpoint_dir,
            model_parallel_size=model_parallel_size,
        )

        with open(yaml_output_path, "w") as yaml_file:
            yaml_file.write(yaml_content.strip())

        print(f"YAML configuration has been written to {yaml_output_path}")

    def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
        checkpoint_dir, model_parallel_size = self.read_user_inputs()
        checkpoint_dir = os.path.expanduser(checkpoint_dir)

        assert (
            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
        ), f"{checkpoint_dir} does not exist or is not a directory"

        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
        self.write_output_yaml(
            checkpoint_dir,
            model_parallel_size,
            yaml_output_path,
        )
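
Note: write_output_yaml() fills a packaged YAML template via str.format. A minimal sketch of that templating step, assuming a template with {checkpoint_dir} and {model_parallel_size} placeholders (the real default_distribution_config.yaml is not part of this diff):

# Hypothetical template text; the shipped default_distribution_config.yaml is
# not shown in this commit, but it must expose these two placeholders for
# write_output_yaml() to work.
template = """\
checkpoint_dir: {checkpoint_dir}
model_parallel_size: {model_parallel_size}
"""

rendered = template.format(
    checkpoint_dir="~/.llama/checkpoints/Meta-Llama-3-8B",
    model_parallel_size=1,
)
print(rendered.strip())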

View file

@@ -7,6 +7,7 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand
+from .configure import DistributionConfigure
 from .create import DistributionCreate
 from .install import DistributionInstall
 from .list import DistributionList
@@ -28,3 +29,4 @@ class DistributionParser(Subcommand):
         DistributionList.create(subparsers)
         DistributionInstall.create(subparsers)
         DistributionCreate.create(subparsers)
+        DistributionConfigure.create(subparsers)
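
Note: the wiring above follows the CLI's Subcommand pattern, where each subcommand attaches a parser and binds its handler through set_defaults(func=...). A minimal, self-contained sketch of that pattern (illustrative names, not the real CLI):

import argparse

# Each subcommand registers itself on a subparsers action and binds a
# handler via set_defaults(func=...); the top-level dispatcher then just
# calls args.func(args).
parser = argparse.ArgumentParser(prog="llama")
subparsers = parser.add_subparsers()

configure = subparsers.add_parser("configure")
configure.add_argument("--name", default="local-source")
configure.set_defaults(func=lambda args: print(f"configuring {args.name}"))

args = parser.parse_args(["configure", "--name", "local-ollama"])
args.func(args)  # prints: configuring local-ollama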

View file

@@ -6,6 +6,8 @@
 import argparse
 import os
+import shlex
+import subprocess
 from pathlib import Path
@@ -13,10 +15,12 @@ import pkg_resources
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.distribution.registry import all_registered_distributions
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR

-CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
+DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
+DISTRIBS = all_registered_distributions()


 class DistributionInstall(Subcommand):
@@ -34,59 +38,45 @@ class DistributionInstall(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_install_cmd)

     def _add_arguments(self):
-        distribs = all_registered_distributions()
         self.parser.add_argument(
             "--name",
             type=str,
-            help="Name of the distribution to install",
-            default="local-source",
-            choices=[d.name for d in distribs],
+            help="Name of the distribution to install -- (try local-ollama)",
+            required=True,
+            choices=[d.name for d in DISTRIBS],
         )
+        self.parser.add_argument(
+            "--conda-env",
+            type=str,
+            help="Specify the name of the conda environment you would like to create or update",
+            required=True,
+        )

-    def read_user_inputs(self):
-        checkpoint_dir = input(
-            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
-        )
-        model_parallel_size = input(
-            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
-        )
-        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
-            1,
-            8,
-        }, "model parallel size must be 1 or 8"
-        return checkpoint_dir, model_parallel_size
-
-    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
-        default_conf_path = pkg_resources.resource_filename(
-            "llama_toolchain", "data/default_distribution_config.yaml"
-        )
-        with open(default_conf_path, "r") as f:
-            yaml_content = f.read()
-        yaml_content = yaml_content.format(
-            checkpoint_dir=checkpoint_dir,
-            model_parallel_size=model_parallel_size,
-        )
-        with open(yaml_output_path, "w") as yaml_file:
-            yaml_file.write(yaml_content.strip())
-        print(f"YAML configuration has been written to {yaml_output_path}")
-
     def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
-        checkpoint_dir, model_parallel_size = self.read_user_inputs()
-        checkpoint_dir = os.path.expanduser(checkpoint_dir)
-
-        assert (
-            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
-        ), f"{checkpoint_dir} does not exist or is not a directory"
-
-        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
-        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
-        self.write_output_yaml(
-            checkpoint_dir,
-            model_parallel_size,
-            yaml_output_path,
-        )
+        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
+        script = pkg_resources.resource_filename(
+            "llama_toolchain",
+            "distribution/install_distribution.sh",
+        )
+
+        dist = None
+        for d in DISTRIBS:
+            if d.name == args.name:
+                dist = d
+                break
+
+        if dist is None:
+            self.parser.error(f"Could not find distribution {args.name}")
+            return
+
+        os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
+        run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
+
+        with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
+            f.write(f"{args.conda_env}\n")
+
+
+def run_shell_script(script_path, *args):
+    command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
+    command_list = shlex.split(command_string)
+    print(f"Running command: {command_list}")
+    subprocess.run(command_list, check=True, text=True)
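
Note: the quote-then-split round trip in run_shell_script keeps multi-word arguments intact, so the space-joined pip package list reaches the shell script as a single positional argument. A small standalone demonstration:

import shlex

# A space-separated package list must arrive at the script as ONE argument;
# shlex.quote wraps it in single quotes and shlex.split preserves that.
args = ["my-env", "ollama requests"]
command_string = "install_distribution.sh " + " ".join(shlex.quote(str(a)) for a in args)

print(command_string)
# install_distribution.sh my-env 'ollama requests'
print(shlex.split(command_string))
# ['install_distribution.sh', 'my-env', 'ollama requests']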

View file

@@ -25,10 +25,10 @@ from llama_models.sku_list import (
 from termcolor import cprint

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR

-DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
+DEFAULT_CHECKPOINT_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "checkpoints")


 class Download(Subcommand):

View file

@@ -13,10 +13,10 @@ from pathlib import Path
 import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR

-CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
+CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")


 class InferenceConfigure(Subcommand):

View file

@@ -4,12 +4,84 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List
+from enum import Enum
+from typing import Any, Dict, List, Literal, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from strong_typing.schema import json_schema_type
+from typing_extensions import Annotated


-class LlamaStackDistribution(BaseModel):
+@json_schema_type
+class AdapterType(Enum):
+    passthrough_api = "passthrough_api"
+    python_impl = "python_impl"
+    not_implemented = "not_implemented"
+
+
+@json_schema_type
+class PassthroughApiAdapterConfig(BaseModel):
+    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
+    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    headers: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Headers (e.g., authorization) to send with the request",
+    )
+
+
+@json_schema_type
+class PythonImplAdapterConfig(BaseModel):
+    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+    module: str = Field(..., description="The name of the module to import")
+    entrypoint: str = Field(
+        ...,
+        description="The name of the entrypoint function which creates the implementation for the API",
+    )
+    kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="kwargs to pass to the entrypoint"
+    )
+
+
+@json_schema_type
+class NotImplementedAdapterConfig(BaseModel):
+    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
+
+
+# should we define very granular typed classes for each of the PythonImplAdapters we will have?
+# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
+AdapterConfig = Annotated[
+    Union[
+        PassthroughApiAdapterConfig,
+        NotImplementedAdapterConfig,
+        PythonImplAdapterConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class DistributionConfig(BaseModel):
+    inference: AdapterConfig
+    safety: AdapterConfig
+    # configs for each API that the stack provides, e.g.
+    # agentic_system: AdapterConfig
+    # post_training: AdapterConfig
+
+
+class DistributionConfigDefaults(BaseModel):
+    inference: Dict[str, Any] = Field(
+        default_factory=dict, description="Default kwargs for the inference adapter"
+    )
+    safety: Dict[str, Any] = Field(
+        default_factory=dict, description="Default kwargs for the safety adapter"
+    )
+
+
+class Distribution(BaseModel):
     name: str
     description: str
@@ -17,3 +89,5 @@ class LlamaStackDistribution(BaseModel):
     # later, we may have a docker image be the main artifact of
     # a distribution.
     pip_packages: List[str]
+
+    config_defaults: DistributionConfigDefaults
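
Note: the AdapterConfig union above dispatches on the "type" field. A self-contained sketch of how pydantic resolves such a discriminated union (v1.9+ API, with simplified stand-in models rather than the real ones):

from typing import Union

from pydantic import BaseModel, Field, parse_obj_as
from typing_extensions import Annotated, Literal

# Simplified stand-ins for the adapter config models above.
class Passthrough(BaseModel):
    type: Literal["passthrough_api"] = "passthrough_api"
    base_url: str

class NotImplementedCfg(BaseModel):
    type: Literal["not_implemented"] = "not_implemented"

Adapter = Annotated[Union[Passthrough, NotImplementedCfg], Field(discriminator="type")]

# The "type" key selects the concrete model at parse time.
cfg = parse_obj_as(Adapter, {"type": "passthrough_api", "base_url": "http://localhost:5000"})
assert isinstance(cfg, Passthrough)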

View file

@@ -0,0 +1,64 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

error_handler() {
    echo "Error occurred in script at line: ${1}" >&2
    exit 1
}

# Set up the error trap
trap 'error_handler ${LINENO}' ERR

ensure_conda_env_python310() {
    local env_name="$1"
    local pip_dependencies="$2"
    local python_version="3.10"

    # Check if conda command is available
    if ! command -v conda &>/dev/null; then
        echo "Error: conda command not found. Is Conda installed and in your PATH?" >&2
        exit 1
    fi

    # Check if the environment exists
    if conda env list | grep -q "^${env_name} "; then
        echo "Conda environment '${env_name}' exists. Checking Python version..."

        # Check Python version in the environment
        current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)

        if [ "$current_version" = "$python_version" ]; then
            echo "Environment '${env_name}' already has Python ${python_version}. No action needed."
        else
            echo "Updating environment '${env_name}' to Python ${python_version}..."
            conda install -n "${env_name}" python="${python_version}" -y
        fi
    else
        echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
        conda create -n "${env_name}" python="${python_version}" -y
    fi

    # Install pip dependencies
    if [ -n "$pip_dependencies" ]; then
        echo "Installing pip dependencies: $pip_dependencies"
        conda run -n "${env_name}" pip install $pip_dependencies
    fi
}

if [ "$#" -ne 2 ]; then
    echo "Usage: $0 <environment_name> <pip_dependencies>" >&2
    echo "Example: $0 my_env 'numpy pandas scipy'" >&2
    exit 1
fi

env_name="$1"
pip_dependencies="$2"

ensure_conda_env_python310 "$env_name" "$pip_dependencies"

View file

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -6,19 +6,30 @@
 from typing import List

-from .datatypes import LlamaStackDistribution
+from .datatypes import Distribution, DistributionConfigDefaults


-def all_registered_distributions() -> List[LlamaStackDistribution]:
+def all_registered_distributions() -> List[Distribution]:
     return [
-        LlamaStackDistribution(
+        Distribution(
             name="local-source",
-            description="Use code within `llama_toolchain` itself to run model inference and everything on top",
+            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             pip_packages=[],
+            config_defaults=DistributionConfigDefaults(
+                inference={
+                    "max_seq_len": 4096,
+                    "max_batch_size": 1,
+                },
+                safety={},
+            ),
         ),
-        LlamaStackDistribution(
+        Distribution(
             name="local-ollama",
             description="Like local-source, but use ollama for running LLM inference",
-            pip_packages=[],
+            pip_packages=["ollama"],
+            config_defaults=DistributionConfigDefaults(
+                inference={},
+                safety={},
+            ),
         ),
     ]
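
Note: with the registry above, the install command's --name lookup reduces to a linear scan. A sketch of that lookup, assuming llama_toolchain is importable:

from llama_toolchain.distribution.registry import all_registered_distributions

# Mirrors the lookup loop in DistributionInstall._run_distribution_install_cmd.
dist = next(
    (d for d in all_registered_distributions() if d.name == "local-ollama"), None
)
assert dist is not None
print(dist.pip_packages)  # ['ollama']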

View file

@@ -14,7 +14,7 @@ from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf

-DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
+LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")


 def get_root_directory():
@@ -26,7 +26,7 @@ def get_root_directory():


 def get_default_config_dir():
-    return os.path.join(DEFAULT_DUMP_DIR, "configs")
+    return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")


 def parse_config(config_dir: str, config_path: Optional[str] = None) -> str: