Make install + start scripts do proper configuration automatically

This commit is contained in:
Ashwin Bharambe 2024-08-06 21:34:09 -07:00
parent 9e1ca4eeb1
commit e1a7aa4773
5 changed files with 84 additions and 39 deletions

View file

@ -6,7 +6,6 @@
import argparse
import os
import textwrap
import pkg_resources
import yaml
@ -78,32 +77,35 @@ class DistributionInstall(Subcommand):
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
return_code = run_with_pty([script, conda_env, " ".join(deps)])
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
)
return
else:
with open(config_file, "w") as f:
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
)
config_file = distrib_dir / "config.yaml"
with open(config_file, "w") as f:
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
color="green",
)
print(
textwrap.dedent(
f"""
Update your conda environment and configure this distribution by running:
conda deactivate && conda activate {conda_env}
llama distribution configure --name {args.name}
"""
)
)

View file

@ -5,8 +5,8 @@
# the root directory of this source tree.
import argparse
import shlex
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
@ -47,9 +47,8 @@ class DistributionStart(Subcommand):
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.registry import resolve_distribution_spec
from llama_toolchain.distribution.server import main as distribution_server_init
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
if not config_file.exists():
@ -67,16 +66,17 @@ class DistributionStart(Subcommand):
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
python_exe = run_command(shlex.split("which python"))
# simple check, unfortunate
if conda_env not in python_exe:
if not conda_env:
raise ValueError(
f"Please re-run start after activating the `{conda_env}` conda environment first"
f"Could not find Conda environment for distribution `{args.name}`"
)
distribution_server_init(
config_file,
args.port,
disable_ipv6=args.disable_ipv6,
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_distribution.sh",
)
args = [script, conda_env, config_file, "--port", str(args.port)] + (
["--disable-ipv6"] if args.disable_ipv6 else []
)
run_with_pty(args)

View file

@ -66,13 +66,20 @@ ensure_conda_env_python310() {
fi
}
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <environment_name> <pip_dependencies>" >&2
echo "Example: $0 my_env 'numpy pandas scipy'" >&2
if [ "$#" -ne 3 ]; then
echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
echo "Example: $0 my_env local-inline 'numpy pandas scipy'" >&2
exit 1
fi
env_name="$1"
pip_dependencies="$2"
distribution_name="$2"
pip_dependencies="$3"
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
python_interp=$(conda run -n "$env_name" which python)
$python_interp -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"

View file

@ -73,7 +73,6 @@ def available_distribution_specs() -> List[DistributionSpec]:
additional_pip_packages=[
"python-dotenv",
"blobfile",
"codeshield",
"fairscale",
"fastapi",
"fire",
@ -82,6 +81,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
"json-strong-typing",
"pydantic==1.10.13",
"pydantic_core==2.18.2",
"tiktoken",
"uvicorn",
],
provider_specs={x: remote_spec(x) for x in providers},

View file

@ -0,0 +1,36 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
# Set up the error trap
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <environment_name> <script_args...>"
exit 1
fi
env_name="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
python_interp=$(conda run -n "$env_name" which python)
$python_interp -m llama_toolchain.distribution.server "$@"