All the new CLI for api + stack work

This commit is contained in:
Ashwin Bharambe 2024-08-28 15:52:49 -07:00
parent fd3b65b718
commit 197f768636
16 changed files with 459 additions and 486 deletions

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distribution_specs
from llama_toolchain.distribution.package import (
BuildType,
)
allowed_ids = [d.distribution_id for d in available_distribution_specs()]
self.parser.add_argument(
"distribution",
type=str,
choices=allowed_ids,
help="Distribution to build (one of: {})".format(allowed_ids),
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env)",
required=True,
)
self.parser.add_argument(
"--type",
type=str,
default="container",
choices=[v.value for v in BuildType],
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution_spec
from llama_toolchain.distribution.package import (
ApiInput,
BuildType,
build_package,
)
dist = resolve_distribution_spec(args.distribution)
if dist is None:
self.parser.error(f"Could not find distribution {args.distribution}")
return
api_inputs = []
for api, provider_id in dist.providers.items():
api_inputs.append(
ApiInput(
api=api,
provider=provider_id,
dependencies={},
)
)
build_package(
api_inputs,
build_type=BuildType(args.type),
name=args.name,
distribution_id=args.distribution,
docker_image=dist.docker_image,
)

View file

@ -6,13 +6,14 @@
import argparse
import json
import shlex
from pathlib import Path
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackConfigure(Subcommand):
@ -22,85 +23,61 @@ class StackConfigure(Subcommand):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama distribution configure",
prog="llama stack configure",
description="configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"--build-name",
type=str,
help="Name of the distribution to configure",
help="Name of the stack build to configure",
required=True,
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.datatypes import StackConfig
from llama_toolchain.distribution.registry import resolve_distribution_spec
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.build_name
if not name.endswith(".yaml"):
name += ".yaml"
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = BUILDS_BASE_DIR / "stack" / name
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = StackConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.spec)
if dist is None:
raise ValueError(f"Could not find any registered spec `{config.spec}`")
configure_llama_distribution(dist, config)
configure_llama_distribution(config_file)
def configure_llama_distribution(dist: "Stack", config: "StackConfig"):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
def configure_llama_distribution(config_file: Path) -> None:
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.dynamic import instantiate_class_type
from llama_toolchain.distribution.configure import configure_api_providers
from llama_toolchain.distribution.registry import resolve_distribution_spec
python_exe = run_command(shlex.split("which python"))
# simple check
conda_env = config.conda_env
if conda_env not in python_exe:
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.distribution_id)
if dist is None:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
f"Could not find any registered distribution `{config.distribution_id}`"
)
if config.providers:
cprint(
f"Configuration already exists for {config.name}. Will overwrite...",
f"Configuration already exists for {config.distribution_id}. Will overwrite...",
"yellow",
attrs=["bold"],
)
for api, provider_spec in dist.provider_specs.items():
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
provider_config = prompt_for_config(
config_type,
(
config_type(**config.providers[api.value])
if api.value in config.providers
else None
),
)
print("")
config.providers = configure_api_providers(config.providers)
config.providers[api.value] = {
"provider_id": provider_spec.provider_id,
**provider_config.dict(),
}
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")
print(f"YAML configuration has been written to {config_file}")

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
class StackCreate(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"create",
prog="llama distribution create",
description="create a Llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_create_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to create",
required=True,
)
# for each Api the user wants to support, we should
# get the list of available providers, ask which one the user
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution_spec
dist = resolve_distribution_spec(args.name)
if dist is not None:
self.parser.error(f"Stack with name {args.name} already exists")
return
raise NotImplementedError()

View file

@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import pkg_resources
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class StackInstall(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"install",
prog="llama distribution install",
description="Install a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_install_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distribution_specs
self.parser.add_argument(
"--spec",
type=str,
help="Stack spec to install (try local-ollama)",
required=True,
choices=[d.spec_id for d in available_distribution_specs()],
)
self.parser.add_argument(
"--name",
type=str,
help="What should the installation be called locally?",
required=True,
)
self.parser.add_argument(
"--conda-env",
type=str,
help="conda env in which this distribution will run (default = distribution name)",
)
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import StackConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/install_distribution.sh",
)
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.spec}")
return
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = StackConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
)
return
else:
with open(config_file, "w") as f:
c = StackConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
)
cprint(
f"Stack `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
color="green",
)

View file

@ -14,9 +14,9 @@ class StackList(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama distribution list",
description="Show available llama stack distributions",
"list-distributions",
prog="llama stack list-distributions",
description="Show available Llama Stack Distributions",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
@ -31,17 +31,17 @@ class StackList(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Spec ID",
"ProviderSpecs",
"Distribution ID",
"Providers",
"Description",
]
rows = []
for spec in available_distribution_specs():
providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
providers = {k.value: v for k, v in spec.providers.items()}
rows.append(
[
spec.spec_id,
spec.distribution_id,
json.dumps(providers, indent=2),
spec.description,
]

View file

@ -8,9 +8,8 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure
from .create import StackCreate
from .install import StackInstall
from .list import StackList
from .start import StackStart
@ -19,16 +18,15 @@ class StackParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"distribution",
prog="llama distribution",
description="Operate on llama stack distributions",
"stack",
prog="llama stack",
description="Operations for the Llama Stack / Distributions",
)
subparsers = self.parser.add_subparsers(title="distribution_subcommands")
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackList.create(subparsers)
StackInstall.create(subparsers)
StackCreate.create(subparsers)
StackBuild.create(subparsers)
StackConfigure.create(subparsers)
StackList.create(subparsers)
StackStart.create(subparsers)

View file

@ -6,11 +6,13 @@
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackStart(Subcommand):
@ -18,19 +20,18 @@ class StackStart(Subcommand):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama distribution start",
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
prog="llama stack start",
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_start_cmd)
self.parser.set_defaults(func=self._run_stack_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"yaml_config",
type=str,
help="Name of the distribution to start",
required=True,
help="Yaml config containing the API build configuration",
)
self.parser.add_argument(
"--port",
@ -45,37 +46,45 @@ class StackStart(Subcommand):
default=False,
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
def _run_stack_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = Path(args.yaml_config)
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
if not conda_env:
raise ValueError(
f"Could not find Conda environment for distribution `{args.name}`"
if not config.distribution_id:
# this is technically not necessary. everything else continues to work,
# but maybe we want to be very clear for the users
self.parser.error(
"No distribution_id found. Did you want to start a provider?"
)
return
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_distribution.sh",
)
args = [script, conda_env, config_file, "--port", str(args.port)] + (
["--disable-ipv6"] if args.disable_ipv6 else []
)
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_container.sh",
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_conda_env.sh",
)
run_args = [
script,
config.conda_env,
]
run_with_pty(args)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
run_with_pty(run_args)