All the new CLI for api + stack work

Ashwin Bharambe 2024-08-28 15:52:49 -07:00
parent fd3b65b718
commit 197f768636
16 changed files with 459 additions and 486 deletions

View file

@@ -5,52 +5,12 @@
# the root directory of this source tree.
import argparse
import json
import os
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
import pkg_resources
import yaml
from termcolor import cprint
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
pip_packages = provider.pip_packages
for dep in dependencies.values():
if dep.docker_image:
raise ValueError(
"You can only have the root provider specify a docker image"
)
pip_packages.extend(dep.pip_packages)
return Dependencies(
docker_image=provider.docker_image,
pip_packages=pip_packages + SERVER_DEPENDENCIES,
)
def parse_dependencies(
dependencies: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
@@ -89,6 +49,9 @@ class ApiBuild(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.distribution import stack_apis
from llama_toolchain.distribution.package import (
BuildType,
)
allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument(
@@ -123,101 +86,20 @@ class ApiBuild(Subcommand):
)
def _run_api_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.common.serialize import EnumEncoder
os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
all_providers = api_providers()
api = Api(args.api)
assert api in all_providers
providers = all_providers[api]
if args.provider not in providers:
self.parser.error(
f"Provider `{args.provider}` is not available for API `{api}`"
)
return
if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{args.name}"
else:
package_name = f"env-{args.provider}-{args.name}"
package_name = package_name.replace("::", "-")
build_dir = BUILDS_BASE_DIR / args.api
os.makedirs(build_dir, exist_ok=True)
# get these names straight. too confusing.
provider_deps = parse_dependencies(args.dependencies or "", self.parser)
dependencies = get_dependencies(providers[args.provider], provider_deps)
package_file = build_dir / f"{package_name}.yaml"
stub_config = {
api.value: {
"provider_id": args.provider,
},
**provider_deps,
}
# properly handle the case where package exists but has
# inconsistent configuration for the providers. if possible,
# we don't want to overwrite the existing configuration.
if package_file.exists():
cprint(
f"Build `{package_name}` exists; will reconfigure",
color="yellow",
)
c = PackageConfig(**yaml.safe_load(package_file.read_text()))
else:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
providers=stub_config,
)
c.docker_image = (
package_name if args.type == BuildType.container.value else None
from llama_toolchain.distribution.package import (
ApiInput,
BuildType,
build_package,
)
c.conda_env = package_name if args.type == BuildType.conda_env.value else None
with open(package_file, "w") as f:
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
if args.type == BuildType.container.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_container.sh"
)
args = [
script,
args.api,
package_name,
dependencies.docker_image or "python:3.10-slim",
" ".join(dependencies.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_conda_env.sh"
)
args = [
script,
args.api,
package_name,
" ".join(dependencies.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {package_name} with return code {return_code}",
color="red",
)
return
cprint(
f"Target `{package_name}` built with configuration at {str(package_file)}",
color="green",
api_input = ApiInput(
api=Api(args.api),
provider=args.provider,
dependencies=parse_dependencies(args.dependencies or "", self.parser),
)
build_package(
[api_input],
build_type=BuildType(args.type),
name=args.name,
)
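Note: the per-target packaging logic deleted above (choosing an image vs. conda env name, writing the build YAML under BUILDS_BASE_DIR, then shelling out to build_container.sh or build_conda_env.sh) is now expected to live behind build_package in llama_toolchain/distribution/package.py. A minimal sketch of that consolidation, reconstructed from the deleted lines; the stand-in types and the naming/signature details beyond what this diff shows are assumptions:

# Hypothetical sketch of build_package(); the real implementation lives in
# llama_toolchain/distribution/package.py and is not part of this diff.
import os
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

import pkg_resources
import yaml

from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR


class BuildType(Enum):
    # mirrors the enum previously defined inline in cli/api/build.py
    container = "container"
    conda_env = "conda_env"


@dataclass
class ApiInput:
    # stand-in for the real model; the diff only shows these three fields being passed
    api: str
    provider: str
    dependencies: Dict[str, object] = field(default_factory=dict)


def build_package(
    api_inputs: List[ApiInput],
    build_type: BuildType,
    name: str,
    distribution_id: Optional[str] = None,
    docker_image: Optional[str] = None,
) -> None:
    # Name the target along the lines of the deleted CLI code (image-*/env-*);
    # the exact scheme used by the real function is assumed.
    prefix = "image" if build_type == BuildType.container else "env"
    package_name = f"{prefix}-{name}".replace("::", "-")

    # api builds used BUILDS_BASE_DIR/<api>; stack builds use BUILDS_BASE_DIR/stack
    subdir = "stack" if distribution_id else api_inputs[0].api
    build_dir = BUILDS_BASE_DIR / subdir
    os.makedirs(build_dir, exist_ok=True)
    package_file = build_dir / f"{package_name}.yaml"

    # Write the stub config; `llama api/stack configure` fills in the provider details.
    config = {
        "built_at": datetime.now().isoformat(),
        "package_name": package_name,
        "distribution_id": distribution_id,
        "docker_image": package_name if build_type == BuildType.container else None,
        "conda_env": package_name if build_type == BuildType.conda_env else None,
        "providers": {inp.api: {"provider_id": inp.provider} for inp in api_inputs},
    }
    with open(package_file, "w") as f:
        f.write(yaml.dump(config, sort_keys=False))

    # Delegate to the packaged shell scripts, as the deleted code did; docker_image,
    # when provided, would be the container base image.
    script_name = (
        "distribution/build_container.sh"
        if build_type == BuildType.container
        else "distribution/build_conda_env.sh"
    )
    script = pkg_resources.resource_filename("llama_toolchain", script_name)
    # run_with_pty([script, ...]) would be invoked here with the pip dependencies.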

View file

@@ -14,7 +14,6 @@ import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
from termcolor import cprint
class ApiConfigure(Subcommand):
@@ -41,14 +40,14 @@ class ApiConfigure(Subcommand):
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
"--name",
"--build-name",
type=str,
help="Name of the provider build to fully configure",
required=True,
)
def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.name
name = args.build_name
if not name.endswith(".yaml"):
name += ".yaml"
config_file = BUILDS_BASE_DIR / args.api / name
@@ -62,48 +61,14 @@ class ApiConfigure(Subcommand):
def configure_llama_provider(config_file: Path) -> None:
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.distribution.dynamic import instantiate_class_type
from llama_toolchain.distribution.configure import configure_api_providers
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
all_providers = api_providers()
config.providers = configure_api_providers(config.providers)
provider_configs = {}
for api, stub_config in config.providers.items():
providers = all_providers[Api(api)]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_spec = providers[provider_id]
cprint(
f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
)
config_type = instantiate_class_type(provider_spec.config_class)
try:
existing_provider_config = config_type(**stub_config)
except Exception:
existing_provider_config = None
provider_config = prompt_for_config(
config_type,
existing_provider_config,
)
print("")
provider_configs[api] = {
"provider_id": provider_id,
**provider_config.dict(),
}
config.providers = provider_configs
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
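Note: the per-provider prompting removed here moves behind configure_api_providers, imported above from llama_toolchain.distribution.configure. A rough sketch of what that shared helper presumably does, reconstructed from the deleted block (module paths follow the imports visible in this diff; the actual implementation may differ):

# Hypothetical sketch of configure_api_providers(); the shared helper lives in
# llama_toolchain/distribution/configure.py, which is not shown in this diff.
from typing import Any, Dict

from termcolor import cprint

from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.distribution.datatypes import Api  # via the star import above
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.distribution.dynamic import instantiate_class_type


def configure_api_providers(existing: Dict[str, Any]) -> Dict[str, Any]:
    all_providers = api_providers()
    updated: Dict[str, Any] = {}
    for api, stub_config in existing.items():
        providers = all_providers[Api(api)]
        provider_id = stub_config["provider_id"]
        if provider_id not in providers:
            raise ValueError(
                f"Unknown provider `{provider_id}` is not available for API `{api}`"
            )
        provider_spec = providers[provider_id]
        cprint(f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"])

        # Re-use any existing values as defaults when prompting.
        config_type = instantiate_class_type(provider_spec.config_class)
        try:
            existing_config = config_type(**stub_config)
        except Exception:
            existing_config = None

        provider_config = prompt_for_config(config_type, existing_config)
        updated[api] = {"provider_id": provider_id, **provider_config.dict()}
    return updated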

View file

@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distribution_specs
from llama_toolchain.distribution.package import (
BuildType,
)
allowed_ids = [d.distribution_id for d in available_distribution_specs()]
self.parser.add_argument(
"distribution",
type=str,
choices=allowed_ids,
help="Distribution to build (one of: {})".format(allowed_ids),
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env)",
required=True,
)
self.parser.add_argument(
"--type",
type=str,
default="container",
choices=[v.value for v in BuildType],
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution_spec
from llama_toolchain.distribution.package import (
ApiInput,
BuildType,
build_package,
)
dist = resolve_distribution_spec(args.distribution)
if dist is None:
self.parser.error(f"Could not find distribution {args.distribution}")
return
api_inputs = []
for api, provider_id in dist.providers.items():
api_inputs.append(
ApiInput(
api=api,
provider=provider_id,
dependencies={},
)
)
build_package(
api_inputs,
build_type=BuildType(args.type),
name=args.name,
distribution_id=args.distribution,
docker_image=dist.docker_image,
)
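Note: for a registered distribution this simply expands the spec's providers mapping (the new Api -> provider_id shape that also appears in the list-distributions change below) into ApiInputs. An illustrative equivalent of running `llama stack build local-ollama --name my-local-stack --type conda_env`; the distribution id and build name are placeholders assumed for the example:

# Illustrative only: assumes the "local-ollama" spec keeps that id after this
# rename and uses a made-up build name.
from llama_toolchain.distribution.package import ApiInput, BuildType, build_package
from llama_toolchain.distribution.registry import resolve_distribution_spec

dist = resolve_distribution_spec("local-ollama")
api_inputs = [
    ApiInput(api=api, provider=provider_id, dependencies={})
    for api, provider_id in dist.providers.items()
]
build_package(
    api_inputs,
    build_type=BuildType.conda_env,
    name="my-local-stack",
    distribution_id="local-ollama",
    docker_image=dist.docker_image,
)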

View file

@@ -6,13 +6,14 @@
import argparse
import json
import shlex
from pathlib import Path
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackConfigure(Subcommand):
@@ -22,85 +23,61 @@ class StackConfigure(Subcommand):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama distribution configure",
prog="llama stack configure",
description="configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"--build-name",
type=str,
help="Name of the distribution to configure",
help="Name of the stack build to configure",
required=True,
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.datatypes import StackConfig
from llama_toolchain.distribution.registry import resolve_distribution_spec
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.build_name
if not name.endswith(".yaml"):
name += ".yaml"
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = BUILDS_BASE_DIR / "stack" / name
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = StackConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.spec)
if dist is None:
raise ValueError(f"Could not find any registered spec `{config.spec}`")
configure_llama_distribution(dist, config)
configure_llama_distribution(config_file)
def configure_llama_distribution(dist: "Stack", config: "StackConfig"):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
def configure_llama_distribution(config_file: Path) -> None:
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.dynamic import instantiate_class_type
from llama_toolchain.distribution.configure import configure_api_providers
from llama_toolchain.distribution.registry import resolve_distribution_spec
python_exe = run_command(shlex.split("which python"))
# simple check
conda_env = config.conda_env
if conda_env not in python_exe:
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.distribution_id)
if dist is None:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
f"Could not find any registered distribution `{config.distribution_id}`"
)
if config.providers:
cprint(
f"Configuration already exists for {config.name}. Will overwrite...",
f"Configuration already exists for {config.distribution_id}. Will overwrite...",
"yellow",
attrs=["bold"],
)
for api, provider_spec in dist.provider_specs.items():
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
provider_config = prompt_for_config(
config_type,
(
config_type(**config.providers[api.value])
if api.value in config.providers
else None
),
)
print("")
config.providers = configure_api_providers(config.providers)
config.providers[api.value] = {
"provider_id": provider_spec.provider_id,
**provider_config.dict(),
}
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")
print(f"YAML configuration has been written to {config_file}")

View file

@@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
class StackCreate(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"create",
prog="llama distribution create",
description="create a Llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_create_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to create",
required=True,
)
# for each Api the user wants to support, we should
# get the list of available providers, ask which one the user
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution_spec
dist = resolve_distribution_spec(args.name)
if dist is not None:
self.parser.error(f"Stack with name {args.name} already exists")
return
raise NotImplementedError()

View file

@@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import pkg_resources
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class StackInstall(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"install",
prog="llama distribution install",
description="Install a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_install_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distribution_specs
self.parser.add_argument(
"--spec",
type=str,
help="Stack spec to install (try local-ollama)",
required=True,
choices=[d.spec_id for d in available_distribution_specs()],
)
self.parser.add_argument(
"--name",
type=str,
help="What should the installation be called locally?",
required=True,
)
self.parser.add_argument(
"--conda-env",
type=str,
help="conda env in which this distribution will run (default = distribution name)",
)
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import StackConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/install_distribution.sh",
)
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.spec}")
return
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = StackConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
)
return
else:
with open(config_file, "w") as f:
c = StackConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
)
cprint(
f"Stack `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
color="green",
)

View file

@@ -14,9 +14,9 @@ class StackList(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama distribution list",
description="Show available llama stack distributions",
"list-distributions",
prog="llama stack list-distributions",
description="Show available Llama Stack Distributions",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
@@ -31,17 +31,17 @@ class StackList(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Spec ID",
"ProviderSpecs",
"Distribution ID",
"Providers",
"Description",
]
rows = []
for spec in available_distribution_specs():
providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
providers = {k.value: v for k, v in spec.providers.items()}
rows.append(
[
spec.spec_id,
spec.distribution_id,
json.dumps(providers, indent=2),
spec.description,
]

View file

@@ -8,9 +8,8 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure
from .create import StackCreate
from .install import StackInstall
from .list import StackList
from .start import StackStart
@@ -19,16 +18,15 @@ class StackParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"distribution",
prog="llama distribution",
description="Operate on llama stack distributions",
"stack",
prog="llama stack",
description="Operations for the Llama Stack / Distributions",
)
subparsers = self.parser.add_subparsers(title="distribution_subcommands")
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackList.create(subparsers)
StackInstall.create(subparsers)
StackCreate.create(subparsers)
StackBuild.create(subparsers)
StackConfigure.create(subparsers)
StackList.create(subparsers)
StackStart.create(subparsers)

View file

@@ -6,11 +6,13 @@
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackStart(Subcommand):
@@ -18,19 +20,18 @@
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama distribution start",
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
prog="llama stack start",
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_start_cmd)
self.parser.set_defaults(func=self._run_stack_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"yaml_config",
type=str,
help="Name of the distribution to start",
required=True,
help="Yaml config containing the API build configuration",
)
self.parser.add_argument(
"--port",
@@ -45,37 +46,45 @@ class StackStart(Subcommand):
default=False,
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
def _run_stack_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = Path(args.yaml_config)
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
if not conda_env:
raise ValueError(
f"Could not find Conda environment for distribution `{args.name}`"
if not config.distribution_id:
# this is technically not necessary. everything else continues to work,
# but maybe we want to be very clear for the users
self.parser.error(
"No distribution_id found. Did you want to start a provider?"
)
return
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_distribution.sh",
)
args = [script, conda_env, config_file, "--port", str(args.port)] + (
["--disable-ipv6"] if args.disable_ipv6 else []
)
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_container.sh",
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_conda_env.sh",
)
run_args = [
script,
config.conda_env,
]
run_with_pty(args)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
run_with_pty(run_args)