Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)
More progress towards llama distribution install

commit dac2b5a1ed (parent 5a583cf16e)
11 changed files with 298 additions and 75 deletions
llama_toolchain/cli/distribution/configure.py (new file, +92)
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os

from pathlib import Path

import pkg_resources

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import all_registered_distributions
from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR


CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")


class DistributionConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama distribution configure",
            description="configure a llama stack distribution",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_configure_cmd)

    def _add_arguments(self):
        distribs = all_registered_distributions()
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to configure",
            default="local-source",
            choices=[d.name for d in distribs],
        )

    def read_user_inputs(self):
        checkpoint_dir = input(
            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
        )
        model_parallel_size = input(
            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
        )
        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
            1,
            8,
        }, "model parallel size must be 1 or 8"

        return checkpoint_dir, model_parallel_size

    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
        default_conf_path = pkg_resources.resource_filename(
            "llama_toolchain", "data/default_distribution_config.yaml"
        )
        with open(default_conf_path, "r") as f:
            yaml_content = f.read()

        yaml_content = yaml_content.format(
            checkpoint_dir=checkpoint_dir,
            model_parallel_size=model_parallel_size,
        )

        with open(yaml_output_path, "w") as yaml_file:
            yaml_file.write(yaml_content.strip())

        print(f"YAML configuration has been written to {yaml_output_path}")

    def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
        checkpoint_dir, model_parallel_size = self.read_user_inputs()
        checkpoint_dir = os.path.expanduser(checkpoint_dir)

        assert (
            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
        ), f"{checkpoint_dir} does not exist or is not a directory"

        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"

        self.write_output_yaml(
            checkpoint_dir,
            model_parallel_size,
            yaml_output_path,
        )
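
Aside: write_output_yaml above is plain string templating, not YAML-aware rendering. A minimal sketch of the same substitution, assuming the packaged default_distribution_config.yaml uses {checkpoint_dir} and {model_parallel_size} placeholders (the template itself is not part of this diff, so its content here is hypothetical):

# Hypothetical stand-in for data/default_distribution_config.yaml;
# the real template ships inside llama_toolchain and is not shown in this diff.
template = "inference:\n  checkpoint_dir: {checkpoint_dir}\n  model_parallel_size: {model_parallel_size}\n"

# The same str.format substitution write_output_yaml performs.
print(template.format(checkpoint_dir="~/.llama/checkpoints/Meta-Llama-3-8B", model_parallel_size=1))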

@@ -7,6 +7,7 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand
+from .configure import DistributionConfigure
 from .create import DistributionCreate
 from .install import DistributionInstall
 from .list import DistributionList
@@ -28,3 +29,4 @@ class DistributionParser(Subcommand):
         DistributionList.create(subparsers)
         DistributionInstall.create(subparsers)
         DistributionCreate.create(subparsers)
+        DistributionConfigure.create(subparsers)
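
The wiring above follows the standard argparse subcommand pattern: each subcommand registers itself on the shared subparsers object and binds its handler via set_defaults(func=...). A self-contained sketch of that dispatch (the names here are illustrative, not the actual CLI entry point):

import argparse

parser = argparse.ArgumentParser(prog="llama")
subparsers = parser.add_subparsers()

configure = subparsers.add_parser("configure")
configure.set_defaults(func=lambda args: print("configure handler called"))

args = parser.parse_args(["configure"])
args.func(args)  # dispatches to whatever handler set_defaults registered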

@@ -6,6 +6,8 @@

 import argparse
 import os
+import shlex
+import subprocess

 from pathlib import Path

@@ -13,10 +15,12 @@ import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.distribution.registry import all_registered_distributions
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR


-CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
+DISTRIBS_BASE_DIR = Path(LLAMA_STACK_CONFIG_DIR) / "distributions"
+
+DISTRIBS = all_registered_distributions()


 class DistributionInstall(Subcommand):
@@ -34,59 +38,45 @@ class DistributionInstall(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_install_cmd)

     def _add_arguments(self):
-        distribs = all_registered_distributions()
         self.parser.add_argument(
             "--name",
             type=str,
-            help="Mame of the distribution to install",
-            default="local-source",
-            choices=[d.name for d in distribs],
+            help="Name of the distribution to install -- (try local-ollama)",
+            required=True,
+            choices=[d.name for d in DISTRIBS],
         )
-
-    def read_user_inputs(self):
-        checkpoint_dir = input(
-            "Enter the checkpoint directory for the model (e.g., ~/.llama/checkpoints/Meta-Llama-3-8B/): "
-        )
-        model_parallel_size = input(
-            "Enter model parallel size (e.g., 1 for 8B / 8 for 70B and 405B): "
-        )
-        assert model_parallel_size.isdigit() and int(model_parallel_size) in {
-            1,
-            8,
-        }, "model parallel size must be 1 or 8"
-
-        return checkpoint_dir, model_parallel_size
-
-    def write_output_yaml(self, checkpoint_dir, model_parallel_size, yaml_output_path):
-        default_conf_path = pkg_resources.resource_filename(
-            "llama_toolchain", "data/default_distribution_config.yaml"
-        )
-        with open(default_conf_path, "r") as f:
-            yaml_content = f.read()
-
-        yaml_content = yaml_content.format(
-            checkpoint_dir=checkpoint_dir,
-            model_parallel_size=model_parallel_size,
-        )
-
-        with open(yaml_output_path, "w") as yaml_file:
-            yaml_file.write(yaml_content.strip())
-
-        print(f"YAML configuration has been written to {yaml_output_path}")
+        self.parser.add_argument(
+            "--conda-env",
+            type=str,
+            help="Specify the name of the conda environment you would like to create or update",
+            required=True,
+        )

     def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
-        checkpoint_dir, model_parallel_size = self.read_user_inputs()
-        checkpoint_dir = os.path.expanduser(checkpoint_dir)
-
-        assert (
-            Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir()
-        ), f"{checkpoint_dir} does not exist or it not a directory"
-
-        os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
-        yaml_output_path = Path(CONFIGS_BASE_DIR) / "distribution.yaml"
-
-        self.write_output_yaml(
-            checkpoint_dir,
-            model_parallel_size,
-            yaml_output_path,
+        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
+        script = pkg_resources.resource_filename(
+            "llama_toolchain",
+            "distribution/install_distribution.sh",
         )
+
+        dist = None
+        for d in DISTRIBS:
+            if d.name == args.name:
+                dist = d
+                break
+
+        if dist is None:
+            self.parser.error(f"Could not find distribution {args.name}")
+            return
+
+        os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
+        run_shell_script(script, args.conda_env, " ".join(dist.pip_packages))
+        with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
+            f.write(f"{args.conda_env}\n")
+
+
+def run_shell_script(script_path, *args):
+    command_string = f"{script_path} {' '.join(shlex.quote(str(arg)) for arg in args)}"
+    command_list = shlex.split(command_string)
+    print(f"Running command: {command_list}")
+    subprocess.run(command_list, check=True, text=True)
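
Note the quote-then-split round trip in run_shell_script: each argument is shell-quoted before the command string is re-split, so the space-joined pip package list survives as a single argv element. A small sketch under the same logic (the env name, path, and package list are made up):

import shlex

script_path = "install_distribution.sh"  # illustrative path
args = ("myenv", "ollama numpy")         # conda env name + space-joined pip packages
command_string = f"{script_path} {' '.join(shlex.quote(str(a)) for a in args)}"
print(shlex.split(command_string))
# -> ['install_distribution.sh', 'myenv', 'ollama numpy']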

@@ -25,10 +25,10 @@ from llama_models.sku_list import (
 from termcolor import cprint

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR


-DEFAULT_CHECKPOINT_DIR = os.path.join(DEFAULT_DUMP_DIR, "checkpoints")
+DEFAULT_CHECKPOINT_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "checkpoints")


 class Download(Subcommand):

@@ -13,10 +13,10 @@ from pathlib import Path
 import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.utils import DEFAULT_DUMP_DIR
+from llama_toolchain.utils import LLAMA_STACK_CONFIG_DIR


-CONFIGS_BASE_DIR = os.path.join(DEFAULT_DUMP_DIR, "configs")
+CONFIGS_BASE_DIR = os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")


 class InferenceConfigure(Subcommand):

@@ -4,12 +4,84 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List
+from enum import Enum
+from typing import Any, Dict, List, Literal, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from strong_typing.schema import json_schema_type
+from typing_extensions import Annotated


-class LlamaStackDistribution(BaseModel):
+@json_schema_type
+class AdapterType(Enum):
+    passthrough_api = "passthrough_api"
+    python_impl = "python_impl"
+    not_implemented = "not_implemented"
+
+
+@json_schema_type
+class PassthroughApiAdapterConfig(BaseModel):
+    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
+    base_url: str = Field(..., description="The base URL for the llama stack provider")
+    headers: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Headers (e.g., authorization) to send with the request",
+    )
+
+
+@json_schema_type
+class PythonImplAdapterConfig(BaseModel):
+    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+    module: str = Field(..., description="The name of the module to import")
+    entrypoint: str = Field(
+        ...,
+        description="The name of the entrypoint function which creates the implementation for the API",
+    )
+    kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="kwargs to pass to the entrypoint"
+    )
+
+
+@json_schema_type
+class NotImplementedAdapterConfig(BaseModel):
+    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
+
+
+# should we define very granular typed classes for each of the PythonImplAdapters we will have?
+# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
+AdapterConfig = Annotated[
+    Union[
+        PassthroughApiAdapterConfig,
+        NotImplementedAdapterConfig,
+        PythonImplAdapterConfig,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class DistributionConfig(BaseModel):
+    inference: AdapterConfig
+    safety: AdapterConfig
+
+    # configs for each API that the stack provides, e.g.
+    # agentic_system: AdapterConfig
+    # post_training: AdapterConfig
+
+
+class DistributionConfigDefaults(BaseModel):
+    inference: Dict[str, Any] = Field(
+        default_factory=dict, description="Default kwargs for the inference adapter"
+    )
+    safety: Dict[str, Any] = Field(
+        default_factory=dict, description="Default kwargs for the safety adapter"
+    )
+
+
+class Distribution(BaseModel):
     name: str
     description: str

@@ -17,3 +89,5 @@ class LlamaStackDistribution(BaseModel):
     # later, we may have a docker image be the main artifact of
     # a distribution.
     pip_packages: List[str]
+
+    config_defaults: DistributionConfigDefaults
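
Because AdapterConfig is an Annotated union with Field(discriminator="type"), pydantic picks the concrete adapter class from the "type" key at parse time. A hedged sketch of what that buys (assumes pydantic v1.9+ semantics, matching the BaseModel/Field style above; the URL is made up):

from pydantic import parse_obj_as

cfg = parse_obj_as(
    AdapterConfig,
    {"type": "passthrough_api", "base_url": "http://localhost:5000"},
)
assert isinstance(cfg, PassthroughApiAdapterConfig)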

llama_toolchain/distribution/install_distribution.sh (new executable file, +64)

@@ -0,0 +1,64 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

# Set up the error trap
trap 'error_handler ${LINENO}' ERR

ensure_conda_env_python310() {
  local env_name="$1"
  local pip_dependencies="$2"
  local python_version="3.10"

  # Check if conda command is available
  if ! command -v conda &>/dev/null; then
    echo "Error: conda command not found. Is Conda installed and in your PATH?" >&2
    exit 1
  fi

  # Check if the environment exists
  if conda env list | grep -q "^${env_name} "; then
    echo "Conda environment '${env_name}' exists. Checking Python version..."

    # Check Python version in the environment
    current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)

    if [ "$current_version" = "$python_version" ]; then
      echo "Environment '${env_name}' already has Python ${python_version}. No action needed."
    else
      echo "Updating environment '${env_name}' to Python ${python_version}..."
      conda install -n "${env_name}" python="${python_version}" -y
    fi
  else
    echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
    conda create -n "${env_name}" python="${python_version}" -y
  fi

  # Install pip dependencies
  if [ -n "$pip_dependencies" ]; then
    echo "Installing pip dependencies: $pip_dependencies"
    conda run -n "${env_name}" pip install $pip_dependencies
  fi
}

if [ "$#" -ne 2 ]; then
  echo "Usage: $0 <environment_name> <pip_dependencies>" >&2
  echo "Example: $0 my_env 'numpy pandas scipy'" >&2
  exit 1
fi

env_name="$1"
pip_dependencies="$2"

ensure_conda_env_python310 "$env_name" "$pip_dependencies"

@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.

@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.

@@ -6,19 +6,30 @@

 from typing import List

-from .datatypes import LlamaStackDistribution
+from .datatypes import Distribution, DistributionConfigDefaults


-def all_registered_distributions() -> List[LlamaStackDistribution]:
+def all_registered_distributions() -> List[Distribution]:
     return [
-        LlamaStackDistribution(
+        Distribution(
             name="local-source",
-            description="Use code within `llama_toolchain` itself to run model inference and everything on top",
+            description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             pip_packages=[],
+            config_defaults=DistributionConfigDefaults(
+                inference={
+                    "max_seq_len": 4096,
+                    "max_batch_size": 1,
+                },
+                safety={},
+            ),
         ),
-        LlamaStackDistribution(
+        Distribution(
             name="local-ollama",
             description="Like local-source, but use ollama for running LLM inference",
-            pip_packages=[],
+            pip_packages=["ollama"],
+            config_defaults=DistributionConfigDefaults(
+                inference={},
+                safety={},
+            ),
         ),
     ]
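
Usage sketch for the updated registry: every Distribution now carries config_defaults, so callers (e.g. the configure command) can seed per-API kwargs without prompting:

dists = all_registered_distributions()
local = next(d for d in dists if d.name == "local-source")
print(local.pip_packages)               # []
print(local.config_defaults.inference)  # {'max_seq_len': 4096, 'max_batch_size': 1}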

@@ -14,7 +14,7 @@ from hydra.core.global_hydra import GlobalHydra
 from omegaconf import OmegaConf


-DEFAULT_DUMP_DIR = os.path.expanduser("~/.llama/")
+LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")


 def get_root_directory():
@@ -26,7 +26,7 @@ def get_root_directory():


 def get_default_config_dir():
-    return os.path.join(DEFAULT_DUMP_DIR, "configs")
+    return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")


 def parse_config(config_dir: str, config_path: Optional[str] = None) -> str: