diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index 8ce3c04b9..68d42938d 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -6,7 +6,6 @@ import argparse import os -import textwrap import pkg_resources import yaml @@ -78,32 +77,35 @@ class DistributionInstall(Subcommand): print(f"Using {args.name} as the Conda environment for this distribution") conda_env = args.conda_env or args.name - return_code = run_with_pty([script, conda_env, " ".join(deps)]) + + config_file = distrib_dir / "config.yaml" + if config_file.exists(): + c = DistributionConfig(**yaml.safe_load(config_file.read_text())) + if c.spec != dist.spec_id: + self.parser.error( + f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`" + ) + return + if c.conda_env != conda_env: + self.parser.error( + f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`" + ) + return + else: + with open(config_file, "w") as f: + c = DistributionConfig( + spec=dist.spec_id, + name=args.name, + conda_env=conda_env, + ) + f.write(yaml.dump(c.dict(), sort_keys=False)) + + return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)]) assert return_code == 0, cprint( f"Failed to install distribution {dist.spec_id}", color="red" ) - - config_file = distrib_dir / "config.yaml" - with open(config_file, "w") as f: - c = DistributionConfig( - spec=dist.spec_id, - name=args.name, - conda_env=conda_env, - ) - f.write(yaml.dump(c.dict(), sort_keys=False)) - cprint( f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!", color="green", ) - print( - textwrap.dedent( - f""" - Update your conda environment and configure this distribution by running: - - conda deactivate && conda activate {conda_env} - llama distribution configure --name {args.name} - """ - ) - ) diff --git a/llama_toolchain/cli/distribution/start.py b/llama_toolchain/cli/distribution/start.py index c106a237c..8620550db 100644 --- a/llama_toolchain/cli/distribution/start.py +++ b/llama_toolchain/cli/distribution/start.py @@ -5,8 +5,8 @@ # the root directory of this source tree. import argparse -import shlex +import pkg_resources import yaml from llama_toolchain.cli.subcommand import Subcommand @@ -47,9 +47,8 @@ class DistributionStart(Subcommand): ) def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None: - from llama_toolchain.common.exec import run_command + from llama_toolchain.common.exec import run_with_pty from llama_toolchain.distribution.registry import resolve_distribution_spec - from llama_toolchain.distribution.server import main as distribution_server_init config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml" if not config_file.exists(): @@ -67,16 +66,17 @@ class DistributionStart(Subcommand): raise ValueError(f"Could not find any registered spec `{config['spec']}`") conda_env = config["conda_env"] - - python_exe = run_command(shlex.split("which python")) - # simple check, unfortunate - if conda_env not in python_exe: + if not conda_env: raise ValueError( - f"Please re-run start after activating the `{conda_env}` conda environment first" + f"Could not find Conda environment for distribution `{args.name}`" ) - distribution_server_init( - config_file, - args.port, - disable_ipv6=args.disable_ipv6, + script = pkg_resources.resource_filename( + "llama_toolchain", + "distribution/start_distribution.sh", ) + args = [script, conda_env, config_file, "--port", str(args.port)] + ( + ["--disable-ipv6"] if args.disable_ipv6 else [] + ) + + run_with_pty(args) diff --git a/llama_toolchain/distribution/install_distribution.sh b/llama_toolchain/distribution/install_distribution.sh index 975c205d0..7bfb5e27a 100755 --- a/llama_toolchain/distribution/install_distribution.sh +++ b/llama_toolchain/distribution/install_distribution.sh @@ -66,13 +66,20 @@ ensure_conda_env_python310() { fi } -if [ "$#" -ne 2 ]; then - echo "Usage: $0 " >&2 - echo "Example: $0 my_env 'numpy pandas scipy'" >&2 +if [ "$#" -ne 3 ]; then + echo "Usage: $0 " >&2 + echo "Example: $0 my_env local-inline 'numpy pandas scipy'" >&2 exit 1 fi env_name="$1" -pip_dependencies="$2" +distribution_name="$2" +pip_dependencies="$3" ensure_conda_env_python310 "$env_name" "$pip_dependencies" + +eval "$(conda shell.bash hook)" +conda deactivate && conda activate "$env_name" + +python_interp=$(conda run -n "$env_name" which python) +$python_interp -m llama_toolchain.cli.llama distribution configure --name "$distribution_name" diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 1f3021599..a89bd9b7d 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -73,7 +73,6 @@ def available_distribution_specs() -> List[DistributionSpec]: additional_pip_packages=[ "python-dotenv", "blobfile", - "codeshield", "fairscale", "fastapi", "fire", @@ -82,6 +81,7 @@ def available_distribution_specs() -> List[DistributionSpec]: "json-strong-typing", "pydantic==1.10.13", "pydantic_core==2.18.2", + "tiktoken", "uvicorn", ], provider_specs={x: remote_spec(x) for x in providers}, diff --git a/llama_toolchain/distribution/start_distribution.sh b/llama_toolchain/distribution/start_distribution.sh new file mode 100755 index 000000000..271919676 --- /dev/null +++ b/llama_toolchain/distribution/start_distribution.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +set -euo pipefail + +# Define color codes +RED='\033[0;31m' +NC='\033[0m' # No Color + +error_handler() { + echo "Error occurred in script at line: ${1}" >&2 + exit 1 +} + +# Set up the error trap +trap 'error_handler ${LINENO}' ERR + +if [ $# -lt 2 ]; then + echo "Usage: $0 " + exit 1 +fi + + +env_name="$1" +shift + +eval "$(conda shell.bash hook)" +conda deactivate && conda activate "$env_name" + +python_interp=$(conda run -n "$env_name" which python) +$python_interp -m llama_toolchain.distribution.server "$@"