All the new CLI for api + stack work

This commit is contained in:
Ashwin Bharambe 2024-08-28 15:52:49 -07:00
parent fd3b65b718
commit 197f768636
16 changed files with 459 additions and 486 deletions

View file

@ -16,13 +16,14 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from .api import * # noqa: F403
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
from .event_logger import EventLogger
async def get_client_impl(base_url: str):
return AgenticSystemClient(base_url)
async def get_client_impl(config: RemoteProviderConfig, _deps):
return AgenticSystemClient(config.url)
def encodable_dict(d: BaseModel):

View file

@ -5,52 +5,12 @@
# the root directory of this source tree.
import argparse
import json
import os
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
import pkg_resources
import yaml
from termcolor import cprint
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
pip_packages = provider.pip_packages
for dep in dependencies.values():
if dep.docker_image:
raise ValueError(
"You can only have the root provider specify a docker image"
)
pip_packages.extend(dep.pip_packages)
return Dependencies(
docker_image=provider.docker_image,
pip_packages=pip_packages + SERVER_DEPENDENCIES,
)
def parse_dependencies(
dependencies: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
@ -89,6 +49,9 @@ class ApiBuild(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.distribution import stack_apis
from llama_toolchain.distribution.package import (
BuildType,
)
allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument(
@ -123,101 +86,20 @@ class ApiBuild(Subcommand):
)
def _run_api_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.common.serialize import EnumEncoder
os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
all_providers = api_providers()
api = Api(args.api)
assert api in all_providers
providers = all_providers[api]
if args.provider not in providers:
self.parser.error(
f"Provider `{args.provider}` is not available for API `{api}`"
)
return
if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{args.name}"
else:
package_name = f"env-{args.provider}-{args.name}"
package_name = package_name.replace("::", "-")
build_dir = BUILDS_BASE_DIR / args.api
os.makedirs(build_dir, exist_ok=True)
# get these names straight. too confusing.
provider_deps = parse_dependencies(args.dependencies or "", self.parser)
dependencies = get_dependencies(providers[args.provider], provider_deps)
package_file = build_dir / f"{package_name}.yaml"
stub_config = {
api.value: {
"provider_id": args.provider,
},
**provider_deps,
}
# properly handle the case where package exists but has
# inconsistent configuration for the providers. if possible,
# we don't want to overwrite the existing configuration.
if package_file.exists():
cprint(
f"Build `{package_name}` exists; will reconfigure",
color="yellow",
)
c = PackageConfig(**yaml.safe_load(package_file.read_text()))
else:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
providers=stub_config,
)
c.docker_image = (
package_name if args.type == BuildType.container.value else None
from llama_toolchain.distribution.package import (
ApiInput,
BuildType,
build_package,
)
c.conda_env = package_name if args.type == BuildType.conda_env.value else None
with open(package_file, "w") as f:
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
if args.type == BuildType.container.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_container.sh"
)
args = [
script,
args.api,
package_name,
dependencies.docker_image or "python:3.10-slim",
" ".join(dependencies.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_conda_env.sh"
)
args = [
script,
args.api,
package_name,
" ".join(dependencies.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {package_name} with return code {return_code}",
color="red",
)
return
cprint(
f"Target `{package_name}` built with configuration at {str(package_file)}",
color="green",
api_input = ApiInput(
api=Api(args.api),
provider=args.provider,
dependencies=parse_dependencies(args.dependencies or "", self.parser),
)
build_package(
[api_input],
build_type=BuildType(args.type),
name=args.name,
)

View file

@ -14,7 +14,6 @@ import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
from termcolor import cprint
class ApiConfigure(Subcommand):
@ -41,14 +40,14 @@ class ApiConfigure(Subcommand):
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
"--name",
"--build-name",
type=str,
help="Name of the provider build to fully configure",
required=True,
)
def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.name
name = args.build_name
if not name.endswith(".yaml"):
name += ".yaml"
config_file = BUILDS_BASE_DIR / args.api / name
@ -62,48 +61,14 @@ class ApiConfigure(Subcommand):
def configure_llama_provider(config_file: Path) -> None:
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.distribution.dynamic import instantiate_class_type
from llama_toolchain.distribution.configure import configure_api_providers
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
all_providers = api_providers()
config.providers = configure_api_providers(config.providers)
provider_configs = {}
for api, stub_config in config.providers.items():
providers = all_providers[Api(api)]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_spec = providers[provider_id]
cprint(
f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
)
config_type = instantiate_class_type(provider_spec.config_class)
try:
existing_provider_config = config_type(**stub_config)
except Exception:
existing_provider_config = None
provider_config = prompt_for_config(
config_type,
existing_provider_config,
)
print("")
provider_configs[api] = {
"provider_id": provider_id,
**provider_config.dict(),
}
config.providers = provider_configs
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackBuild(Subcommand):
    """`llama stack build` subcommand.

    Builds a full Llama Stack distribution target (a container image or a
    conda env) from a registered DistributionSpec by delegating to
    `build_package` with one ApiInput per API the distribution provides.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "build",
            prog="llama stack build",
            description="Build a Llama stack container",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_stack_build_command)

    def _add_arguments(self):
        # Imports are deferred so `llama --help` stays fast and does not pull
        # in the whole distribution machinery.
        from llama_toolchain.distribution.registry import available_distribution_specs
        from llama_toolchain.distribution.package import (
            BuildType,
        )

        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
        self.parser.add_argument(
            "distribution",
            type=str,
            choices=allowed_ids,
            help="Distribution to build (one of: {})".format(allowed_ids),
        )
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the build target (image, conda env)",
            required=True,
        )
        self.parser.add_argument(
            "--type",
            type=str,
            default="container",
            choices=[v.value for v in BuildType],
        )

    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        from llama_toolchain.distribution.registry import resolve_distribution_spec
        from llama_toolchain.distribution.package import (
            ApiInput,
            BuildType,
            build_package,
        )

        dist = resolve_distribution_spec(args.distribution)
        if dist is None:
            # parser.error() raises SystemExit; the `return` is defensive.
            self.parser.error(f"Could not find distribution {args.distribution}")
            return

        # One ApiInput per API in the distribution. Per-provider extra
        # dependencies are not configurable from this command, hence {}.
        api_inputs = []
        for api, provider_id in dist.providers.items():
            api_inputs.append(
                ApiInput(
                    api=api,
                    provider=provider_id,
                    dependencies={},
                )
            )

        build_package(
            api_inputs,
            build_type=BuildType(args.type),
            name=args.name,
            distribution_id=args.distribution,
            docker_image=dist.docker_image,
        )

View file

@ -6,13 +6,14 @@
import argparse
import json
import shlex
from pathlib import Path
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackConfigure(Subcommand):
@ -22,85 +23,61 @@ class StackConfigure(Subcommand):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama distribution configure",
prog="llama stack configure",
description="configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_configure_cmd)
self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"--build-name",
type=str,
help="Name of the distribution to configure",
help="Name of the stack build to configure",
required=True,
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.datatypes import StackConfig
from llama_toolchain.distribution.registry import resolve_distribution_spec
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.build_name
if not name.endswith(".yaml"):
name += ".yaml"
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = BUILDS_BASE_DIR / "stack" / name
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = StackConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.spec)
if dist is None:
raise ValueError(f"Could not find any registered spec `{config.spec}`")
configure_llama_distribution(dist, config)
configure_llama_distribution(config_file)
def configure_llama_distribution(dist: "Stack", config: "StackConfig"):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
def configure_llama_distribution(config_file: Path) -> None:
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.dynamic import instantiate_class_type
from llama_toolchain.distribution.configure import configure_api_providers
from llama_toolchain.distribution.registry import resolve_distribution_spec
python_exe = run_command(shlex.split("which python"))
# simple check
conda_env = config.conda_env
if conda_env not in python_exe:
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config.distribution_id)
if dist is None:
raise ValueError(
f"Please re-run configure by activating the `{conda_env}` conda environment"
f"Could not find any registered distribution `{config.distribution_id}`"
)
if config.providers:
cprint(
f"Configuration already exists for {config.name}. Will overwrite...",
f"Configuration already exists for {config.distribution_id}. Will overwrite...",
"yellow",
attrs=["bold"],
)
for api, provider_spec in dist.provider_specs.items():
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
provider_config = prompt_for_config(
config_type,
(
config_type(**config.providers[api.value])
if api.value in config.providers
else None
),
)
print("")
config.providers = configure_api_providers(config.providers)
config.providers[api.value] = {
"provider_id": provider_spec.provider_id,
**provider_config.dict(),
}
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
with open(config_path, "w") as fp:
dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(dist_config, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")
print(f"YAML configuration has been written to {config_file}")

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
class StackCreate(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"create",
prog="llama distribution create",
description="create a Llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_create_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to create",
required=True,
)
# for each Api the user wants to support, we should
# get the list of available providers, ask which one the user
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.distribution.registry import resolve_distribution_spec
dist = resolve_distribution_spec(args.name)
if dist is not None:
self.parser.error(f"Stack with name {args.name} already exists")
return
raise NotImplementedError()

View file

@ -1,111 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import pkg_resources
import yaml
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
class StackInstall(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"install",
prog="llama distribution install",
description="Install a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_install_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distribution_specs
self.parser.add_argument(
"--spec",
type=str,
help="Stack spec to install (try local-ollama)",
required=True,
choices=[d.spec_id for d in available_distribution_specs()],
)
self.parser.add_argument(
"--name",
type=str,
help="What should the installation be called locally?",
required=True,
)
self.parser.add_argument(
"--conda-env",
type=str,
help="conda env in which this distribution will run (default = distribution name)",
)
def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import StackConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/install_distribution.sh",
)
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.spec}")
return
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = StackConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
)
return
else:
with open(config_file, "w") as f:
c = StackConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
)
cprint(
f"Stack `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
color="green",
)

View file

@ -14,9 +14,9 @@ class StackList(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama distribution list",
description="Show available llama stack distributions",
"list-distributions",
prog="llama stack list-distributions",
description="Show available Llama Stack Distributions",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
@ -31,17 +31,17 @@ class StackList(Subcommand):
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Spec ID",
"ProviderSpecs",
"Distribution ID",
"Providers",
"Description",
]
rows = []
for spec in available_distribution_specs():
providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
providers = {k.value: v for k, v in spec.providers.items()}
rows.append(
[
spec.spec_id,
spec.distribution_id,
json.dumps(providers, indent=2),
spec.description,
]

View file

@ -8,9 +8,8 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import StackBuild
from .configure import StackConfigure
from .create import StackCreate
from .install import StackInstall
from .list import StackList
from .start import StackStart
@ -19,16 +18,15 @@ class StackParser(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"distribution",
prog="llama distribution",
description="Operate on llama stack distributions",
"stack",
prog="llama stack",
description="Operations for the Llama Stack / Distributions",
)
subparsers = self.parser.add_subparsers(title="distribution_subcommands")
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackList.create(subparsers)
StackInstall.create(subparsers)
StackCreate.create(subparsers)
StackBuild.create(subparsers)
StackConfigure.create(subparsers)
StackList.create(subparsers)
StackStart.create(subparsers)

View file

@ -6,11 +6,13 @@
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class StackStart(Subcommand):
@ -18,19 +20,18 @@ class StackStart(Subcommand):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama distribution start",
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
prog="llama stack start",
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_start_cmd)
self.parser.set_defaults(func=self._run_stack_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
"yaml_config",
type=str,
help="Name of the distribution to start",
required=True,
help="Yaml config containing the API build configuration",
)
self.parser.add_argument(
"--port",
@ -45,37 +46,45 @@ class StackStart(Subcommand):
default=False,
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
def _run_stack_start_cmd(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.registry import resolve_distribution_spec
config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
config_file = Path(args.yaml_config)
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama distribution install` first"
f"Could not find {config_file}. Please run `llama stack build` first"
)
return
# we need to find the spec from the name
with open(config_file, "r") as f:
config = yaml.safe_load(f)
config = PackageConfig(**yaml.safe_load(f))
dist = resolve_distribution_spec(config["spec"])
if dist is None:
raise ValueError(f"Could not find any registered spec `{config['spec']}`")
conda_env = config["conda_env"]
if not conda_env:
raise ValueError(
f"Could not find Conda environment for distribution `{args.name}`"
if not config.distribution_id:
# this is technically not necessary. everything else continues to work,
# but maybe we want to be very clear for the users
self.parser.error(
"No distribution_id found. Did you want to start a provider?"
)
return
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_distribution.sh",
)
args = [script, conda_env, config_file, "--port", str(args.port)] + (
["--disable-ipv6"] if args.disable_ipv6 else []
)
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_container.sh",
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/start_conda_env.sh",
)
run_args = [
script,
config.conda_env,
]
run_with_pty(args)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
run_with_pty(run_args)

View file

@ -116,4 +116,12 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"
printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}"
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name"
if [ "$api_or_stack" = "stack" ]; then
subcommand="stack"
target=""
else
subcommand="api"
target="$api_or_stack"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --build-name "$env_name"

View file

@ -110,4 +110,12 @@ set +x
printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
echo "You can run it with: podman run -p 8000:8000 $image_name"
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$image_name"
if [ "$api_or_stack" = "stack" ]; then
subcommand="stack"
target=""
else
subcommand="api"
target="$api_or_stack"
fi
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --build-name "$image_name"

View file

@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_toolchain.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.distribution.dynamic import instantiate_class_type
def configure_api_providers(existing_configs: Dict[str, Any]) -> Dict[str, Any]:
    """Interactively (re)configure each API provider in *existing_configs*.

    Args:
        existing_configs: mapping of API name -> stub or previously-saved
            provider config; each value must contain a "provider_id" key.

    Returns:
        A new mapping of API name -> fully prompted provider configuration
        (including the "provider_id" key). The input mapping is not mutated.

    Raises:
        ValueError: if a referenced provider_id is not registered for its API.
    """
    all_providers = api_providers()

    provider_configs: Dict[str, Any] = {}
    for api_str, stub_config in existing_configs.items():
        api = Api(api_str)
        providers = all_providers[api]
        provider_id = stub_config["provider_id"]
        if provider_id not in providers:
            raise ValueError(
                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
            )

        provider_spec = providers[provider_id]
        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
        config_type = instantiate_class_type(provider_spec.config_class)

        # Seed the prompt with the existing config when it still validates
        # against the provider's config class; otherwise start from scratch.
        try:
            existing_provider_config = config_type(**stub_config)
        except Exception:
            existing_provider_config = None

        provider_config = prompt_for_config(
            config_type,
            existing_provider_config,
        )
        print("")

        provider_configs[api_str] = {
            "provider_id": provider_id,
            **provider_config.dict(),
        }

    return provider_configs

View file

@ -97,7 +97,7 @@ class RemoteProviderConfig(BaseModel):
def validate_url(cls, url: str) -> str:
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url
return url.rstrip("/")
def remote_provider_id(adapter_id: str) -> str:
@ -150,12 +150,13 @@ def remote_provider_spec(
@json_schema_type
class DistributionSpec(BaseModel):
spec_id: str
distribution_id: str
description: str
provider_specs: Dict[Api, ProviderSpec] = Field(
docker_image: Optional[str] = None
providers: Dict[Api, str] = Field(
default_factory=dict,
description="Provider specifications for each of the APIs provided by this distribution",
description="Provider IDs for each of the APIs provided by this distribution",
)
@ -170,6 +171,8 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
this could be just a hash
""",
)
distribution_id: Optional[str] = None
docker_image: Optional[str] = Field(
default=None,
description="Reference to the docker image if this package refers to a container",

View file

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
import pkg_resources
import yaml
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.distribution import api_providers
class BuildType(Enum):
    """How a build target is materialized."""

    # Build a container (docker/podman) image.
    container = "container"
    # Build a conda environment on the local machine.
    conda_env = "conda_env"
class Dependencies(BaseModel):
    """Resolved dependency set for a provider build."""

    # pip packages to install into the build target.
    pip_packages: List[str]
    # Base docker image supplied by the root provider, if any; None otherwise.
    docker_image: Optional[str] = None
def get_dependencies(
    provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
    """Collect the full pip dependency closure for *provider*.

    Args:
        provider: the root provider being built; only it may specify a
            docker image.
        dependencies: additional provider specs (keyed by provider id) whose
            pip packages are folded into the result.

    Returns:
        A Dependencies with the root provider's docker image (if any) and the
        combined pip packages plus the server's own dependencies.

    Raises:
        ValueError: if any non-root dependency declares a docker image.
    """
    from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES

    # Copy the list: extending `provider.pip_packages` directly would mutate
    # the (shared, possibly cached) ProviderSpec in place, so repeated calls
    # would keep accumulating dependency packages onto the spec itself.
    pip_packages = list(provider.pip_packages)
    for dep in dependencies.values():
        if dep.docker_image:
            raise ValueError(
                "You can only have the root provider specify a docker image"
            )
        pip_packages.extend(dep.pip_packages)

    return Dependencies(
        docker_image=provider.docker_image,
        pip_packages=pip_packages + SERVER_DEPENDENCIES,
    )
class ApiInput(BaseModel):
    """One API surface to include in a package build."""

    # The API being served (e.g. inference, safety).
    api: Api
    # Provider ID that will serve this API.
    provider: str
    # Extra provider specs this provider depends on, keyed by provider id.
    dependencies: Dict[str, ProviderSpec]
def build_package(
    api_inputs: List[ApiInput],
    build_type: BuildType,
    name: str,
    distribution_id: Optional[str] = None,
    docker_image: Optional[str] = None,
):
    """Build (or reconfigure) a package serving one API or a whole stack.

    Writes/merges a PackageConfig YAML under BUILDS_BASE_DIR, then invokes the
    matching build script (container or conda env) via a pty.

    Args:
        api_inputs: the APIs to include; more than one input means a "stack".
        build_type: container image vs conda environment.
        name: user-chosen suffix for the build target name.
        distribution_id: required when building a stack; recorded in the config.
        docker_image: base image override; defaults to python:3.10-slim.
    """
    is_stack = len(api_inputs) > 1
    if is_stack:
        if not distribution_id:
            raise ValueError(
                "You must specify a distribution name when building the Llama Stack"
            )

    api1 = api_inputs[0]
    # For a stack, name the package after the distribution; for a single API,
    # after its provider.
    provider = distribution_id if is_stack else api1.provider
    api_or_stack = "stack" if is_stack else api1.api.value
    build_desc = "image" if build_type == BuildType.container else "env"

    build_dir = BUILDS_BASE_DIR / api_or_stack
    os.makedirs(build_dir, exist_ok=True)

    package_name = f"{build_desc}-{provider}-{name}"
    # "::" appears in namespaced provider ids and is not filesystem-friendly.
    package_name = package_name.replace("::", "-")
    package_file = build_dir / f"{package_name}.yaml"

    all_providers = api_providers()

    package_deps = Dependencies(
        docker_image=docker_image or "python:3.10-slim",
        pip_packages=[],
    )
    stub_config = {}
    for api_input in api_inputs:
        api = api_input.api
        providers_for_api = all_providers[api]
        if api_input.provider not in providers_for_api:
            raise ValueError(
                f"Provider `{api_input.provider}` is not available for API `{api}`"
            )

        deps = get_dependencies(
            providers_for_api[api_input.provider],
            api_input.dependencies,
        )
        # NOTE(review): get_dependencies propagates the provider's own
        # docker_image, so any provider that declares one is rejected here —
        # even for a single-API build. Confirm this is intended.
        if deps.docker_image:
            raise ValueError("A stack's dependencies cannot have a docker image")
        package_deps.pip_packages.extend(deps.pip_packages)

        stub_config[api.value] = {"provider_id": api_input.provider}

    # Reuse an existing package config when present so prior per-provider
    # configuration survives a rebuild; only replace entries whose provider
    # id actually changed.
    if package_file.exists():
        cprint(
            f"Build `{package_name}` exists; will reconfigure",
            color="yellow",
        )
        c = PackageConfig(**yaml.safe_load(package_file.read_text()))
        for api_str, new_config in stub_config.items():
            if api_str not in c.providers:
                c.providers[api_str] = new_config
            else:
                existing_config = c.providers[api_str]
                if existing_config["provider_id"] != new_config["provider_id"]:
                    cprint(
                        f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
                        color="yellow",
                    )
                    c.providers[api_str] = new_config
    else:
        c = PackageConfig(
            built_at=datetime.now(),
            package_name=package_name,
            providers=stub_config,
        )

    c.distribution_id = distribution_id
    # Exactly one of docker_image / conda_env is set, matching the build type.
    c.docker_image = package_name if build_type == BuildType.container else None
    c.conda_env = package_name if build_type == BuildType.conda_env else None

    with open(package_file, "w") as f:
        # Round-trip through JSON with EnumEncoder so enum values serialize
        # as plain strings in the YAML.
        to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
        f.write(yaml.dump(to_write, sort_keys=False))

    if build_type == BuildType.container:
        script = pkg_resources.resource_filename(
            "llama_toolchain", "distribution/build_container.sh"
        )
        args = [
            script,
            api_or_stack,
            package_name,
            package_deps.docker_image,
            " ".join(package_deps.pip_packages),
        ]
    else:
        script = pkg_resources.resource_filename(
            "llama_toolchain", "distribution/build_conda_env.sh"
        )
        args = [
            script,
            api_or_stack,
            package_name,
            " ".join(package_deps.pip_packages),
        ]

    return_code = run_with_pty(args)
    if return_code != 0:
        cprint(
            f"Failed to build target {package_name} with return code {return_code}",
            color="red",
        )
        return

    cprint(
        f"Target `{package_name}` built with configuration at {str(package_file)}",
        color="green",
    )

View file

@ -8,77 +8,42 @@ from functools import lru_cache
from typing import List, Optional
from .datatypes import * # noqa: F403
from .distribution import api_providers
@lru_cache()
def available_distribution_specs() -> List[DistributionSpec]:
providers = api_providers()
return [
DistributionSpec(
spec_id="local",
distribution_id="local",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
provider_specs={
Api.inference: providers[Api.inference]["meta-reference"],
Api.memory: providers[Api.memory]["meta-reference-faiss"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
providers={
Api.inference: "meta-reference",
Api.memory: "meta-reference-faiss",
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
},
),
DistributionSpec(
spec_id="remote",
distribution_id="remote",
description="Point to remote services for all llama stack APIs",
provider_specs={x: remote_provider_spec(x) for x in providers},
providers={x: "remote" for x in Api},
),
DistributionSpec(
spec_id="local-ollama",
distribution_id="local-ollama",
description="Like local, but use ollama for running LLM inference",
provider_specs={
# this is ODD; make this easier -- we just need a better function to retrieve registered providers
Api.inference: providers[Api.inference][remote_provider_id("ollama")],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
Api.memory: providers[Api.memory]["meta-reference-faiss"],
},
),
DistributionSpec(
spec_id="test-agentic",
description="Test agentic with others as remote",
provider_specs={
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
Api.inference: remote_provider_spec(Api.inference),
Api.memory: remote_provider_spec(Api.memory),
Api.safety: remote_provider_spec(Api.safety),
},
),
DistributionSpec(
spec_id="test-inference",
description="Test inference provider",
provider_specs={
Api.inference: providers[Api.inference]["meta-reference"],
},
),
DistributionSpec(
spec_id="test-memory",
description="Test memory provider",
provider_specs={
Api.inference: providers[Api.inference]["meta-reference"],
Api.memory: providers[Api.memory]["meta-reference-faiss"],
},
),
DistributionSpec(
spec_id="test-safety",
description="Test safety provider",
provider_specs={
Api.safety: providers[Api.safety]["meta-reference"],
providers={
Api.inference: remote_provider_id("ollama"),
Api.safety: "meta-reference",
Api.agentic_system: "meta-reference",
Api.memory: "meta-reference-faiss",
},
),
]
@lru_cache()
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
for spec in available_distribution_specs():
if spec.spec_id == spec_id:
if spec.distribution_id == distribution_id:
return spec
return None