All the new CLI for api + stack work

This commit is contained in:
Ashwin Bharambe 2024-08-28 15:52:49 -07:00
parent fd3b65b718
commit 197f768636
16 changed files with 459 additions and 486 deletions

View file

@ -6,11 +6,13 @@
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackStart(Subcommand):
@ -18,19 +20,18 @@ class StackStart(Subcommand):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama distribution start",
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
prog="llama stack start",
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_start_cmd)
self.parser.set_defaults(func=self._run_stack_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"yaml_config",
type=str,
help="Name of the distribution to start",
required=True,
help="Yaml config containing the API build configuration",
)
self.parser.add_argument(
"--port",
@ -45,37 +46,45 @@ class StackStart(Subcommand):
default=False,
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
def _run_stack_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = Path(args.yaml_config)
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
if not conda_env:
raise ValueError(
f"Could not find Conda environment for distribution `{args.name}`"
if not config.distribution_id:
# this is technically not necessary. everything else continues to work,
# but maybe we want to be very clear for the users
self.parser.error(
"No distribution_id found. Did you want to start a provider?"
)
return
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_distribution.sh",
)
args = [script, conda_env, config_file, "--port", str(args.port)] + (
["--disable-ipv6"] if args.disable_ipv6 else []
)
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_container.sh",
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_conda_env.sh",
)
run_args = [
script,
config.conda_env,
]
run_with_pty(args)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
run_with_pty(run_args)