diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index b1bfd03b6..0a14bcf6a 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -16,13 +16,14 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from .api import *  # noqa: F403
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 
+from .api import *  # noqa: F403
 from .event_logger import EventLogger
 
 
-async def get_client_impl(base_url: str):
-    return AgenticSystemClient(base_url)
+async def get_client_impl(config: RemoteProviderConfig, _deps):
+    return AgenticSystemClient(config.url)
 
 
 def encodable_dict(d: BaseModel):
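Note: remote client factories move from `(base_url)` to `(config, deps)` with this change. A minimal sketch of driving the new signature — assuming `url` is the only required field on `RemoteProviderConfig`; the port and the empty deps dict are illustrative, and the trailing-slash normalization comes from the `distribution/datatypes.py` hunk later in this diff:

```python
import asyncio

from llama_toolchain.agentic_system.client import get_client_impl
from llama_toolchain.distribution.datatypes import RemoteProviderConfig


async def main() -> None:
    # validate_url (see the datatypes.py hunk below) rejects non-http URLs
    # and now also strips any trailing slash
    config = RemoteProviderConfig(url="http://localhost:5000/")
    assert config.url == "http://localhost:5000"

    # the second argument carries resolved inter-provider dependencies;
    # the remote client ignores them, hence the `_deps` parameter name
    client = await get_client_impl(config, {})
    print(type(client).__name__)  # AgenticSystemClient


asyncio.run(main())
```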
diff --git a/llama_toolchain/cli/api/build.py b/llama_toolchain/cli/api/build.py
index 248796ac6..05f59c19e 100644
--- a/llama_toolchain/cli/api/build.py
+++ b/llama_toolchain/cli/api/build.py
@@ -5,52 +5,12 @@
 # the root directory of this source tree.
 
 import argparse
-import json
-import os
-
-from pydantic import BaseModel
-from datetime import datetime
-from enum import Enum
-from typing import Dict, List, Optional
-
-import pkg_resources
-import yaml
-
-from termcolor import cprint
+from typing import Dict
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
 from llama_toolchain.distribution.datatypes import *  # noqa: F403
 
 
-class BuildType(Enum):
-    container = "container"
-    conda_env = "conda_env"
-
-
-class Dependencies(BaseModel):
-    pip_packages: List[str]
-    docker_image: Optional[str] = None
-
-
-def get_dependencies(
-    provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
-) -> Dependencies:
-    from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
-
-    pip_packages = provider.pip_packages
-    for dep in dependencies.values():
-        if dep.docker_image:
-            raise ValueError(
-                "You can only have the root provider specify a docker image"
-            )
-        pip_packages.extend(dep.pip_packages)
-
-    return Dependencies(
-        docker_image=provider.docker_image,
-        pip_packages=pip_packages + SERVER_DEPENDENCIES,
-    )
-
-
 def parse_dependencies(
     dependencies: str, parser: argparse.ArgumentParser
 ) -> Dict[str, ProviderSpec]:
@@ -89,6 +49,9 @@ class ApiBuild(Subcommand):
 
     def _add_arguments(self):
         from llama_toolchain.distribution.distribution import stack_apis
+        from llama_toolchain.distribution.package import (
+            BuildType,
+        )
 
         allowed_args = [a.name for a in stack_apis()]
         self.parser.add_argument(
@@ -123,101 +86,20 @@ class ApiBuild(Subcommand):
         )
 
     def _run_api_build_command(self, args: argparse.Namespace) -> None:
-        from llama_toolchain.common.exec import run_with_pty
-        from llama_toolchain.distribution.distribution import api_providers
-        from llama_toolchain.common.serialize import EnumEncoder
-
-        os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
-        all_providers = api_providers()
-
-        api = Api(args.api)
-        assert api in all_providers
-
-        providers = all_providers[api]
-        if args.provider not in providers:
-            self.parser.error(
-                f"Provider `{args.provider}` is not available for API `{api}`"
-            )
-            return
-
-        if args.type == BuildType.container.value:
-            package_name = f"image-{args.provider}-{args.name}"
-        else:
-            package_name = f"env-{args.provider}-{args.name}"
-        package_name = package_name.replace("::", "-")
-
-        build_dir = BUILDS_BASE_DIR / args.api
-        os.makedirs(build_dir, exist_ok=True)
-
-        # get these names straight. too confusing.
-        provider_deps = parse_dependencies(args.dependencies or "", self.parser)
-        dependencies = get_dependencies(providers[args.provider], provider_deps)
-
-        package_file = build_dir / f"{package_name}.yaml"
-
-        stub_config = {
-            api.value: {
-                "provider_id": args.provider,
-            },
-            **provider_deps,
-        }
-
-        # properly handle the case where package exists but has
-        # inconsistent configuration for the providers. if possible,
-        # we don't want to overwrite the existing configuration.
-        if package_file.exists():
-            cprint(
-                f"Build `{package_name}` exists; will reconfigure",
-                color="yellow",
-            )
-            c = PackageConfig(**yaml.safe_load(package_file.read_text()))
-        else:
-            c = PackageConfig(
-                built_at=datetime.now(),
-                package_name=package_name,
-                providers=stub_config,
-            )
-
-        c.docker_image = (
-            package_name if args.type == BuildType.container.value else None
+        from llama_toolchain.distribution.package import (
+            ApiInput,
+            BuildType,
+            build_package,
         )
-        c.conda_env = package_name if args.type == BuildType.conda_env.value else None
-
-        with open(package_file, "w") as f:
-            to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
-            f.write(yaml.dump(to_write, sort_keys=False))
-
-        if args.type == BuildType.container.value:
-            script = pkg_resources.resource_filename(
-                "llama_toolchain", "distribution/build_container.sh"
-            )
-            args = [
-                script,
-                args.api,
-                package_name,
-                dependencies.docker_image or "python:3.10-slim",
-                " ".join(dependencies.pip_packages),
-            ]
-        else:
-            script = pkg_resources.resource_filename(
-                "llama_toolchain", "distribution/build_conda_env.sh"
-            )
-            args = [
-                script,
-                args.api,
-                package_name,
-                " ".join(dependencies.pip_packages),
-            ]
-
-        return_code = run_with_pty(args)
-        if return_code != 0:
-            cprint(
-                f"Failed to build target {package_name} with return code {return_code}",
-                color="red",
-            )
-            return
-
-        cprint(
-            f"Target `{package_name}` built with configuration at {str(package_file)}",
-            color="green",
+        api_input = ApiInput(
+            api=Api(args.api),
+            provider=args.provider,
+            dependencies=parse_dependencies(args.dependencies or "", self.parser),
+        )
+
+        build_package(
+            [api_input],
+            build_type=BuildType(args.type),
+            name=args.name,
         )
diff --git a/llama_toolchain/cli/api/configure.py b/llama_toolchain/cli/api/configure.py
index fe68a2a12..ef48f175a 100644
--- a/llama_toolchain/cli/api/configure.py
+++ b/llama_toolchain/cli/api/configure.py
@@ -14,7 +14,6 @@ import yaml
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
 from llama_toolchain.distribution.datatypes import *  # noqa: F403
-from termcolor import cprint
 
 
 class ApiConfigure(Subcommand):
@@ -41,14 +40,14 @@ class ApiConfigure(Subcommand):
             help="Stack API (one of: {})".format(", ".join(allowed_args)),
         )
         self.parser.add_argument(
-            "--name",
+            "--build-name",
             type=str,
             help="Name of the provider build to fully configure",
             required=True,
         )
 
     def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
-        name = args.name
+        name = args.build_name
         if not name.endswith(".yaml"):
             name += ".yaml"
         config_file = BUILDS_BASE_DIR / args.api / name
@@ -62,48 +61,14 @@ def configure_llama_provider(config_file: Path) -> None:
-    from llama_toolchain.common.prompt_for_config import prompt_for_config
     from llama_toolchain.common.serialize import EnumEncoder
-    from llama_toolchain.distribution.distribution import api_providers
-    from llama_toolchain.distribution.dynamic import instantiate_class_type
+    from llama_toolchain.distribution.configure import configure_api_providers
 
     with open(config_file, "r") as f:
         config = PackageConfig(**yaml.safe_load(f))
 
-    all_providers = api_providers()
+    config.providers = configure_api_providers(config.providers)
 
-    provider_configs = {}
-    for api, stub_config in config.providers.items():
-        providers = all_providers[Api(api)]
-        provider_id = stub_config["provider_id"]
-        if provider_id not in providers:
-            raise ValueError(
-                f"Unknown provider `{provider_id}` is not available for API `{api}`"
-            )
-
-        provider_spec = providers[provider_id]
-        cprint(
-            f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
-        )
-        config_type = instantiate_class_type(provider_spec.config_class)
-
-        try:
-            existing_provider_config = config_type(**stub_config)
-        except Exception:
-            existing_provider_config = None
-
-        provider_config = prompt_for_config(
-            config_type,
-            existing_provider_config,
-        )
-        print("")
-
-        provider_configs[api] = {
-            "provider_id": provider_id,
-            **provider_config.dict(),
-        }
-
-    config.providers = provider_configs
     with open(config_file, "w") as fp:
         to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
         fp.write(yaml.dump(to_write, sort_keys=False))
diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py
new file mode 100644
index 000000000..ef2393c09
--- /dev/null
+++ b/llama_toolchain/cli/stack/build.py
@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+from typing import Dict
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
+
+
+class StackBuild(Subcommand):
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "build",
+            prog="llama stack build",
+            description="Build a Llama stack container",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_stack_build_command)
+
+    def _add_arguments(self):
+        from llama_toolchain.distribution.registry import available_distribution_specs
+        from llama_toolchain.distribution.package import (
+            BuildType,
+        )
+
+        allowed_ids = [d.distribution_id for d in available_distribution_specs()]
+        self.parser.add_argument(
+            "distribution",
+            type=str,
+            choices=allowed_ids,
+            help="Distribution to build (one of: {})".format(allowed_ids),
+        )
+
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the build target (image, conda env)",
+            required=True,
+        )
+        self.parser.add_argument(
+            "--type",
+            type=str,
+            default="container",
+            choices=[v.value for v in BuildType],
+        )
+
+    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
+        from llama_toolchain.distribution.registry import resolve_distribution_spec
+        from llama_toolchain.distribution.package import (
+            ApiInput,
+            BuildType,
+            build_package,
+        )
+
+        dist = resolve_distribution_spec(args.distribution)
+        if dist is None:
+            self.parser.error(f"Could not find distribution {args.distribution}")
+            return
+
+        api_inputs = []
+        for api, provider_id in dist.providers.items():
+            api_inputs.append(
+                ApiInput(
+                    api=api,
+                    provider=provider_id,
+                    dependencies={},
+                )
+            )
+
+        build_package(
+            api_inputs,
+            build_type=BuildType(args.type),
+            name=args.name,
+            distribution_id=args.distribution,
+            docker_image=dist.docker_image,
+        )
diff --git a/llama_toolchain/cli/stack/configure.py b/llama_toolchain/cli/stack/configure.py
index 0e311feaa..2e62238c2 100644
--- a/llama_toolchain/cli/stack/configure.py
+++ b/llama_toolchain/cli/stack/configure.py
@@ -6,13 +6,14 @@
 
 import argparse
 import json
-import shlex
+from pathlib import Path
 
 import yaml
 from termcolor import cprint
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
+from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
 
 
 class StackConfigure(Subcommand):
@@ -22,85 +23,61 @@ class StackConfigure(Subcommand):
         super().__init__()
         self.parser = subparsers.add_parser(
             "configure",
-            prog="llama distribution configure",
+            prog="llama stack configure",
             description="configure a llama stack distribution",
             formatter_class=argparse.RawTextHelpFormatter,
         )
         self._add_arguments()
-        self.parser.set_defaults(func=self._run_distribution_configure_cmd)
+        self.parser.set_defaults(func=self._run_stack_configure_cmd)
 
     def _add_arguments(self):
         self.parser.add_argument(
-            "--name",
+            "--build-name",
            type=str,
-            help="Name of the distribution to configure",
+            help="Name of the stack build to configure",
             required=True,
         )
 
-    def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
-        from llama_toolchain.distribution.datatypes import StackConfig
-        from llama_toolchain.distribution.registry import resolve_distribution_spec
+    def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
+        name = args.build_name
+        if not name.endswith(".yaml"):
+            name += ".yaml"
 
-        config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
+        config_file = BUILDS_BASE_DIR / "stack" / name
         if not config_file.exists():
             self.parser.error(
-                f"Could not find {config_file}. Please run `llama distribution install` first"
+                f"Could not find {config_file}. Please run `llama stack build` first"
             )
             return
 
-        # we need to find the spec from the name
-        with open(config_file, "r") as f:
-            config = StackConfig(**yaml.safe_load(f))
-
-        dist = resolve_distribution_spec(config.spec)
-        if dist is None:
-            raise ValueError(f"Could not find any registered spec `{config.spec}`")
-
-        configure_llama_distribution(dist, config)
+        configure_llama_distribution(config_file)
 
 
-def configure_llama_distribution(dist: "Stack", config: "StackConfig"):
-    from llama_toolchain.common.exec import run_command
-    from llama_toolchain.common.prompt_for_config import prompt_for_config
+def configure_llama_distribution(config_file: Path) -> None:
     from llama_toolchain.common.serialize import EnumEncoder
-    from llama_toolchain.distribution.dynamic import instantiate_class_type
+    from llama_toolchain.distribution.configure import configure_api_providers
+    from llama_toolchain.distribution.registry import resolve_distribution_spec
 
-    python_exe = run_command(shlex.split("which python"))
-    # simple check
-    conda_env = config.conda_env
-    if conda_env not in python_exe:
+    with open(config_file, "r") as f:
+        config = PackageConfig(**yaml.safe_load(f))
+
+    dist = resolve_distribution_spec(config.distribution_id)
+    if dist is None:
         raise ValueError(
-            f"Please re-run configure by activating the `{conda_env}` conda environment"
+            f"Could not find any registered distribution `{config.distribution_id}`"
         )
 
     if config.providers:
         cprint(
-            f"Configuration already exists for {config.name}. Will overwrite...",
+            f"Configuration already exists for {config.distribution_id}. Will overwrite...",
             "yellow",
             attrs=["bold"],
         )
 
-    for api, provider_spec in dist.provider_specs.items():
-        cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
-        config_type = instantiate_class_type(provider_spec.config_class)
-        provider_config = prompt_for_config(
-            config_type,
-            (
-                config_type(**config.providers[api.value])
-                if api.value in config.providers
-                else None
-            ),
-        )
-        print("")
+    config.providers = configure_api_providers(config.providers)
 
-        config.providers[api.value] = {
-            "provider_id": provider_spec.provider_id,
-            **provider_config.dict(),
-        }
+    with open(config_file, "w") as fp:
+        to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
+        fp.write(yaml.dump(to_write, sort_keys=False))
 
-    config_path = DISTRIBS_BASE_DIR / config.name / "config.yaml"
-    with open(config_path, "w") as fp:
-        dist_config = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
-        fp.write(yaml.dump(dist_config, sort_keys=False))
-
-    print(f"YAML configuration has been written to {config_path}")
+    print(f"YAML configuration has been written to {config_file}")
diff --git a/llama_toolchain/cli/stack/create.py b/llama_toolchain/cli/stack/create.py
deleted file mode 100644
index 037681355..000000000
--- a/llama_toolchain/cli/stack/create.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import argparse
-
-from llama_toolchain.cli.subcommand import Subcommand
-
-
-class StackCreate(Subcommand):
-    def __init__(self, subparsers: argparse._SubParsersAction):
-        super().__init__()
-        self.parser = subparsers.add_parser(
-            "create",
-            prog="llama distribution create",
-            description="create a Llama stack distribution",
-            formatter_class=argparse.RawTextHelpFormatter,
-        )
-        self._add_arguments()
-        self.parser.set_defaults(func=self._run_distribution_create_cmd)
-
-    def _add_arguments(self):
-        self.parser.add_argument(
-            "--name",
-            type=str,
-            help="Name of the distribution to create",
-            required=True,
-        )
-        # for each Api the user wants to support, we should
-        # get the list of available providers, ask which one the user
-        # wants to pick and then ask for their configuration.
-
-    def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
-        from llama_toolchain.distribution.registry import resolve_distribution_spec
-
-        dist = resolve_distribution_spec(args.name)
-        if dist is not None:
-            self.parser.error(f"Stack with name {args.name} already exists")
-            return
-
-        raise NotImplementedError()
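Note: together with the `create` command deleted above and the `install` command deleted below, the entry point becomes `llama stack build`. A sketch of where the build artifact lands under the new naming scheme — values are illustrative; the logic mirrors `build_package()` in `distribution/package.py` later in this diff:

```python
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR

# e.g. `llama stack build local-ollama --name dev --type conda_env`
distribution_id = "local-ollama"  # positional argument
name = "dev"                      # --name
build_desc = "env"                # "image" when --type is container

package_name = f"{build_desc}-{distribution_id}-{name}".replace("::", "-")
package_file = BUILDS_BASE_DIR / "stack" / f"{package_name}.yaml"
# -> <BUILDS_BASE_DIR>/stack/env-local-ollama-dev.yaml, the file that
# `llama stack configure --build-name env-local-ollama-dev` reads back
print(package_file)
```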
diff --git a/llama_toolchain/cli/stack/install.py b/llama_toolchain/cli/stack/install.py
deleted file mode 100644
index 9e0a12de9..000000000
--- a/llama_toolchain/cli/stack/install.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import argparse
-import os
-
-import pkg_resources
-import yaml
-
-from termcolor import cprint
-
-from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
-
-
-class StackInstall(Subcommand):
-    """Llama cli for configuring llama toolchain configs"""
-
-    def __init__(self, subparsers: argparse._SubParsersAction):
-        super().__init__()
-        self.parser = subparsers.add_parser(
-            "install",
-            prog="llama distribution install",
-            description="Install a llama stack distribution",
-            formatter_class=argparse.RawTextHelpFormatter,
-        )
-        self._add_arguments()
-        self.parser.set_defaults(func=self._run_distribution_install_cmd)
-
-    def _add_arguments(self):
-        from llama_toolchain.distribution.registry import available_distribution_specs
-
-        self.parser.add_argument(
-            "--spec",
-            type=str,
-            help="Stack spec to install (try local-ollama)",
-            required=True,
-            choices=[d.spec_id for d in available_distribution_specs()],
-        )
-        self.parser.add_argument(
-            "--name",
-            type=str,
-            help="What should the installation be called locally?",
-            required=True,
-        )
-        self.parser.add_argument(
-            "--conda-env",
-            type=str,
-            help="conda env in which this distribution will run (default = distribution name)",
-        )
-
-    def _run_distribution_install_cmd(self, args: argparse.Namespace) -> None:
-        from llama_toolchain.common.exec import run_with_pty
-        from llama_toolchain.distribution.datatypes import StackConfig
-        from llama_toolchain.distribution.distribution import distribution_dependencies
-        from llama_toolchain.distribution.registry import resolve_distribution_spec
-
-        os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
-        script = pkg_resources.resource_filename(
-            "llama_toolchain",
-            "distribution/install_distribution.sh",
-        )
-
-        dist = resolve_distribution_spec(args.spec)
-        if dist is None:
-            self.parser.error(f"Could not find distribution {args.spec}")
-            return
-
-        distrib_dir = DISTRIBS_BASE_DIR / args.name
-        os.makedirs(distrib_dir, exist_ok=True)
-
-        deps = distribution_dependencies(dist)
-        if not args.conda_env:
-            print(f"Using {args.name} as the Conda environment for this distribution")
-
-        conda_env = args.conda_env or args.name
-
-        config_file = distrib_dir / "config.yaml"
-        if config_file.exists():
-            c = StackConfig(**yaml.safe_load(config_file.read_text()))
-            if c.spec != dist.spec_id:
-                self.parser.error(
-                    f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
-                )
-                return
-            if c.conda_env != conda_env:
-                self.parser.error(
-                    f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
-                )
-                return
-        else:
-            with open(config_file, "w") as f:
-                c = StackConfig(
-                    spec=dist.spec_id,
-                    name=args.name,
-                    conda_env=conda_env,
-                )
-                f.write(yaml.dump(c.dict(), sort_keys=False))
-
-        return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
-
-        assert return_code == 0, cprint(
-            f"Failed to install distribution {dist.spec_id}", color="red"
-        )
-        cprint(
-            f"Stack `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
-            color="green",
-        )
diff --git a/llama_toolchain/cli/stack/list.py b/llama_toolchain/cli/stack/list.py
index 217d5ae03..d321947e1 100644
--- a/llama_toolchain/cli/stack/list.py
+++ b/llama_toolchain/cli/stack/list.py
@@ -14,9 +14,9 @@ class StackList(Subcommand):
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
-            "list",
-            prog="llama distribution list",
-            description="Show available llama stack distributions",
+            "list-distributions",
+            prog="llama stack list-distributions",
+            description="Show available Llama Stack Distributions",
             formatter_class=argparse.RawTextHelpFormatter,
         )
         self._add_arguments()
@@ -31,17 +31,17 @@ class StackList(Subcommand):
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [
-            "Spec ID",
-            "ProviderSpecs",
+            "Distribution ID",
+            "Providers",
             "Description",
         ]
 
         rows = []
         for spec in available_distribution_specs():
-            providers = {k.value: v.provider_id for k, v in spec.provider_specs.items()}
+            providers = {k.value: v for k, v in spec.providers.items()}
             rows.append(
                 [
-                    spec.spec_id,
+                    spec.distribution_id,
                     json.dumps(providers, indent=2),
                     spec.description,
                 ]
diff --git a/llama_toolchain/cli/stack/stack.py b/llama_toolchain/cli/stack/stack.py
index 0584f62da..a24cc5f09 100644
--- a/llama_toolchain/cli/stack/stack.py
+++ b/llama_toolchain/cli/stack/stack.py
@@ -8,9 +8,8 @@ import argparse
 
 from llama_toolchain.cli.subcommand import Subcommand
 
+from .build import StackBuild
 from .configure import StackConfigure
-from .create import StackCreate
-from .install import StackInstall
 from .list import StackList
 from .start import StackStart
 
@@ -19,16 +18,15 @@ class StackParser(Subcommand):
     def __init__(self, subparsers: argparse._SubParsersAction):
         super().__init__()
         self.parser = subparsers.add_parser(
-            "distribution",
-            prog="llama distribution",
-            description="Operate on llama stack distributions",
+            "stack",
+            prog="llama stack",
+            description="Operations for the Llama Stack / Distributions",
         )
 
-        subparsers = self.parser.add_subparsers(title="distribution_subcommands")
+        subparsers = self.parser.add_subparsers(title="stack_subcommands")
 
         # Add sub-commands
-        StackList.create(subparsers)
-        StackInstall.create(subparsers)
-        StackCreate.create(subparsers)
+        StackBuild.create(subparsers)
         StackConfigure.create(subparsers)
+        StackList.create(subparsers)
         StackStart.create(subparsers)
diff --git a/llama_toolchain/cli/stack/start.py b/llama_toolchain/cli/stack/start.py
index 9f35e5aa3..7b5ff5912 100644
--- a/llama_toolchain/cli/stack/start.py
+++ b/llama_toolchain/cli/stack/start.py
@@ -6,11 +6,13 @@
 
 import argparse
 
+from pathlib import Path
+
 import pkg_resources
 import yaml
 
 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
 
 
 class StackStart(Subcommand):
@@ -18,19 +20,18 @@
         super().__init__()
         self.parser = subparsers.add_parser(
             "start",
-            prog="llama distribution start",
-            description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
+            prog="llama stack start",
+            description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
             formatter_class=argparse.RawTextHelpFormatter,
         )
         self._add_arguments()
-        self.parser.set_defaults(func=self._run_distribution_start_cmd)
+        self.parser.set_defaults(func=self._run_stack_start_cmd)
 
     def _add_arguments(self):
         self.parser.add_argument(
-            "--name",
+            "yaml_config",
             type=str,
-            help="Name of the distribution to start",
-            required=True,
+            help="Yaml config containing the API build configuration",
         )
         self.parser.add_argument(
             "--port",
@@ -45,37 +46,45 @@
             default=False,
         )
 
-    def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
+    def _run_stack_start_cmd(self, args: argparse.Namespace) -> None:
         from llama_toolchain.common.exec import run_with_pty
-        from llama_toolchain.distribution.registry import resolve_distribution_spec
 
-        config_file = DISTRIBS_BASE_DIR / args.name / "config.yaml"
+        config_file = Path(args.yaml_config)
         if not config_file.exists():
             self.parser.error(
-                f"Could not find {config_file}. Please run `llama distribution install` first"
+                f"Could not find {config_file}. Please run `llama stack build` first"
             )
             return
 
-        # we need to find the spec from the name
         with open(config_file, "r") as f:
-            config = yaml.safe_load(f)
+            config = PackageConfig(**yaml.safe_load(f))
 
-        dist = resolve_distribution_spec(config["spec"])
-        if dist is None:
-            raise ValueError(f"Could not find any registered spec `{config['spec']}`")
-
-        conda_env = config["conda_env"]
-        if not conda_env:
-            raise ValueError(
-                f"Could not find Conda environment for distribution `{args.name}`"
+        if not config.distribution_id:
+            # this is technically not necessary. everything else continues to work,
+            # but maybe we want to be very clear for the users
+            self.parser.error(
+                "No distribution_id found. Did you want to start a provider?"
             )
+            return
 
-        script = pkg_resources.resource_filename(
-            "llama_toolchain",
-            "distribution/start_distribution.sh",
-        )
-        args = [script, conda_env, config_file, "--port", str(args.port)] + (
-            ["--disable-ipv6"] if args.disable_ipv6 else []
-        )
+        if config.docker_image:
+            script = pkg_resources.resource_filename(
+                "llama_toolchain",
+                "distribution/start_container.sh",
+            )
+            run_args = [script, config.docker_image]
+        else:
+            script = pkg_resources.resource_filename(
+                "llama_toolchain",
+                "distribution/start_conda_env.sh",
+            )
+            run_args = [
+                script,
+                config.conda_env,
+            ]
 
-        run_with_pty(args)
+        run_args.extend([str(config_file), str(args.port)])
+        if args.disable_ipv6:
+            run_args.append("--disable-ipv6")
+
+        run_with_pty(run_args)
diff --git a/llama_toolchain/distribution/build_conda_env.sh b/llama_toolchain/distribution/build_conda_env.sh
index 38d10a525..2e103202d 100755
--- a/llama_toolchain/distribution/build_conda_env.sh
+++ b/llama_toolchain/distribution/build_conda_env.sh
@@ -116,4 +116,12 @@ ensure_conda_env_python310 "$env_name" "$pip_dependencies"
 
 printf "${GREEN}Successfully setup conda environment. Configuring build...${NC}"
 
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name"
+if [ "$api_or_stack" = "stack" ]; then
+  subcommand="stack"
+  target=""
+else
+  subcommand="api"
+  target="$api_or_stack"
+fi
+
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --build-name "$env_name"
diff --git a/llama_toolchain/distribution/build_container.sh b/llama_toolchain/distribution/build_container.sh
index 95551abac..d77de165c 100755
--- a/llama_toolchain/distribution/build_container.sh
+++ b/llama_toolchain/distribution/build_container.sh
@@ -110,4 +110,12 @@ set +x
 printf "${GREEN}Succesfully setup Podman image. Configuring build...${NC}"
 echo "You can run it with: podman run -p 8000:8000 $image_name"
 
-$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$image_name"
+if [ "$api_or_stack" = "stack" ]; then
+  subcommand="stack"
+  target=""
+else
+  subcommand="api"
+  target="$api_or_stack"
+fi
+
+$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama $subcommand configure $target --build-name "$image_name"
diff --git a/llama_toolchain/distribution/configure.py b/llama_toolchain/distribution/configure.py
new file mode 100644
index 000000000..f3bc9a1ab
--- /dev/null
+++ b/llama_toolchain/distribution/configure.py
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
+from termcolor import cprint
+
+from llama_toolchain.common.prompt_for_config import prompt_for_config
+from llama_toolchain.distribution.distribution import api_providers
+from llama_toolchain.distribution.dynamic import instantiate_class_type
+
+
+def configure_api_providers(existing_configs: Dict[str, Any]) -> Dict[str, Any]:
+    all_providers = api_providers()
+
+    provider_configs = {}
+    for api_str, stub_config in existing_configs.items():
+        api = Api(api_str)
+        providers = all_providers[api]
+        provider_id = stub_config["provider_id"]
+        if provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
+            )
+
+        provider_spec = providers[provider_id]
+        cprint(f"Configuring API: {api_str} ({provider_id})", "white", attrs=["bold"])
+        config_type = instantiate_class_type(provider_spec.config_class)
+
+        try:
+            existing_provider_config = config_type(**stub_config)
+        except Exception:
+            existing_provider_config = None
+
+        provider_config = prompt_for_config(
+            config_type,
+            existing_provider_config,
+        )
+        print("")
+
+        provider_configs[api_str] = {
+            "provider_id": provider_id,
+            **provider_config.dict(),
+        }
+
+    return provider_configs
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index f286ea5ce..468310542 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -97,7 +97,7 @@ class RemoteProviderConfig(BaseModel):
     def validate_url(cls, url: str) -> str:
         if not url.startswith("http"):
             raise ValueError(f"URL must start with http: {url}")
-        return url
+        return url.rstrip("/")
 
 
 def remote_provider_id(adapter_id: str) -> str:
@@ -150,12 +150,13 @@ def remote_provider_spec(
 
 @json_schema_type
 class DistributionSpec(BaseModel):
-    spec_id: str
+    distribution_id: str
     description: str
 
-    provider_specs: Dict[Api, ProviderSpec] = Field(
+    docker_image: Optional[str] = None
+    providers: Dict[Api, str] = Field(
         default_factory=dict,
-        description="Provider specifications for each of the APIs provided by this distribution",
+        description="Provider IDs for each of the APIs provided by this distribution",
     )
 
 
@@ -170,6 +171,8 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
 this could be just a hash
 """,
     )
+    distribution_id: Optional[str] = None
+
     docker_image: Optional[str] = Field(
         default=None,
         description="Reference to the docker image if this package refers to a container",
diff --git a/llama_toolchain/distribution/package.py b/llama_toolchain/distribution/package.py
new file mode 100644
index 000000000..7b4cf56ca
--- /dev/null
+++ b/llama_toolchain/distribution/package.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import os
+from datetime import datetime
+from enum import Enum
+from typing import Dict, List, Optional
+
+import pkg_resources
+import yaml
+from pydantic import BaseModel
+
+from termcolor import cprint
+
+from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
+
+from llama_toolchain.common.exec import run_with_pty
+from llama_toolchain.common.serialize import EnumEncoder
+from llama_toolchain.distribution.distribution import api_providers
+
+
+class BuildType(Enum):
+    container = "container"
+    conda_env = "conda_env"
+
+
+class Dependencies(BaseModel):
+    pip_packages: List[str]
+    docker_image: Optional[str] = None
+
+
+def get_dependencies(
+    provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
+) -> Dependencies:
+    from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
+
+    pip_packages = provider.pip_packages
+    for dep in dependencies.values():
+        if dep.docker_image:
+            raise ValueError(
+                "You can only have the root provider specify a docker image"
+            )
+        pip_packages.extend(dep.pip_packages)
+
+    return Dependencies(
+        docker_image=provider.docker_image,
+        pip_packages=pip_packages + SERVER_DEPENDENCIES,
+    )
+
+
+class ApiInput(BaseModel):
+    api: Api
+    provider: str
+    dependencies: Dict[str, ProviderSpec]
+
+
+def build_package(
+    api_inputs: List[ApiInput],
+    build_type: BuildType,
+    name: str,
+    distribution_id: Optional[str] = None,
+    docker_image: Optional[str] = None,
+):
+    is_stack = len(api_inputs) > 1
+    if is_stack:
+        if not distribution_id:
+            raise ValueError(
+                "You must specify a distribution name when building the Llama Stack"
+            )
+
+    api1 = api_inputs[0]
+
+    provider = distribution_id if is_stack else api1.provider
+    api_or_stack = "stack" if is_stack else api1.api.value
+    build_desc = "image" if build_type == BuildType.container else "env"
+
+    build_dir = BUILDS_BASE_DIR / api_or_stack
+    os.makedirs(build_dir, exist_ok=True)
+
+    package_name = f"{build_desc}-{provider}-{name}"
+    package_name = package_name.replace("::", "-")
+    package_file = build_dir / f"{package_name}.yaml"
+
+    all_providers = api_providers()
+
+    package_deps = Dependencies(
+        docker_image=docker_image or "python:3.10-slim",
+        pip_packages=[],
+    )
+    stub_config = {}
+    for api_input in api_inputs:
+        api = api_input.api
+        providers_for_api = all_providers[api]
+        if api_input.provider not in providers_for_api:
+            raise ValueError(
+                f"Provider `{api_input.provider}` is not available for API `{api}`"
+            )
+
+        deps = get_dependencies(
+            providers_for_api[api_input.provider],
+            api_input.dependencies,
+        )
+        if deps.docker_image:
+            raise ValueError("A stack's dependencies cannot have a docker image")
+        package_deps.pip_packages.extend(deps.pip_packages)
+
+        stub_config[api.value] = {"provider_id": api_input.provider}
+
+    if package_file.exists():
+        cprint(
+            f"Build `{package_name}` exists; will reconfigure",
+            color="yellow",
+        )
+        c = PackageConfig(**yaml.safe_load(package_file.read_text()))
+        for api_str, new_config in stub_config.items():
+            if api_str not in c.providers:
+                c.providers[api_str] = new_config
+            else:
+                existing_config = c.providers[api_str]
+                if existing_config["provider_id"] != new_config["provider_id"]:
+                    cprint(
+                        f"Provider `{api_str}` has changed from `{existing_config}` to `{new_config}`",
+                        color="yellow",
+                    )
+                    c.providers[api_str] = new_config
+    else:
+        c = PackageConfig(
+            built_at=datetime.now(),
+            package_name=package_name,
+            providers=stub_config,
+        )
+
+    c.distribution_id = distribution_id
+    c.docker_image = package_name if build_type == BuildType.container else None
+    c.conda_env = package_name if build_type == BuildType.conda_env else None
+
+    with open(package_file, "w") as f:
+        to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
+        f.write(yaml.dump(to_write, sort_keys=False))
+
+    if build_type == BuildType.container:
+        script = pkg_resources.resource_filename(
+            "llama_toolchain", "distribution/build_container.sh"
+        )
+        args = [
+            script,
+            api_or_stack,
+            package_name,
+            package_deps.docker_image,
+            " ".join(package_deps.pip_packages),
+        ]
+    else:
+        script = pkg_resources.resource_filename(
+            "llama_toolchain", "distribution/build_conda_env.sh"
+        )
+        args = [
+            script,
+            api_or_stack,
+            package_name,
+            " ".join(package_deps.pip_packages),
+        ]
+
+    return_code = run_with_pty(args)
+    if return_code != 0:
+        cprint(
+            f"Failed to build target {package_name} with return code {return_code}",
+            color="red",
+        )
+        return
+
+    cprint(
+        f"Target `{package_name}` built with configuration at {str(package_file)}",
+        color="green",
+    )
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index e1d49eb05..4a8ec5940 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -8,77 +8,42 @@ from functools import lru_cache
 from typing import List, Optional
 
 from .datatypes import *  # noqa: F403
-from .distribution import api_providers
 
 
 @lru_cache()
 def available_distribution_specs() -> List[DistributionSpec]:
-    providers = api_providers()
     return [
         DistributionSpec(
-            spec_id="local",
+            distribution_id="local",
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
-            provider_specs={
-                Api.inference: providers[Api.inference]["meta-reference"],
-                Api.memory: providers[Api.memory]["meta-reference-faiss"],
-                Api.safety: providers[Api.safety]["meta-reference"],
-                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+            providers={
+                Api.inference: "meta-reference",
+                Api.memory: "meta-reference-faiss",
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
             },
         ),
         DistributionSpec(
-            spec_id="remote",
+            distribution_id="remote",
             description="Point to remote services for all llama stack APIs",
-            provider_specs={x: remote_provider_spec(x) for x in providers},
+            providers={x: "remote" for x in Api},
         ),
         DistributionSpec(
-            spec_id="local-ollama",
+            distribution_id="local-ollama",
             description="Like local, but use ollama for running LLM inference",
-            provider_specs={
-                # this is ODD; make this easier -- we just need a better function to retrieve registered providers
-                Api.inference: providers[Api.inference][remote_provider_id("ollama")],
-                Api.safety: providers[Api.safety]["meta-reference"],
-                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
-                Api.memory: providers[Api.memory]["meta-reference-faiss"],
-            },
-        ),
-        DistributionSpec(
-            spec_id="test-agentic",
-            description="Test agentic with others as remote",
-            provider_specs={
-                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
-                Api.inference: remote_provider_spec(Api.inference),
-                Api.memory: remote_provider_spec(Api.memory),
-                Api.safety: remote_provider_spec(Api.safety),
-            },
-        ),
-        DistributionSpec(
-            spec_id="test-inference",
-            description="Test inference provider",
-            provider_specs={
-                Api.inference: providers[Api.inference]["meta-reference"],
-            },
-        ),
-        DistributionSpec(
-            spec_id="test-memory",
-            description="Test memory provider",
-            provider_specs={
-                Api.inference: providers[Api.inference]["meta-reference"],
-                Api.memory: providers[Api.memory]["meta-reference-faiss"],
-            },
-        ),
-        DistributionSpec(
-            spec_id="test-safety",
-            description="Test safety provider",
-            provider_specs={
-                Api.safety: providers[Api.safety]["meta-reference"],
+            providers={
+                Api.inference: remote_provider_id("ollama"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
             },
         ),
     ]
 
 
 @lru_cache()
-def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
+def resolve_distribution_spec(distribution_id: str) -> Optional[DistributionSpec]:
     for spec in available_distribution_specs():
-        if spec.spec_id == spec_id:
+        if spec.distribution_id == distribution_id:
             return spec
     return None
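Note: a minimal sketch of the round-trip `build_package()` performs when writing the build artifact that `llama stack configure` and `llama stack start` read back. This mirrors `distribution/package.py` above; the provider and distribution ids are illustrative, and it assumes `PackageConfig` requires no fields beyond those exercised in this diff:

```python
import json
from datetime import datetime

import yaml

from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.datatypes import *  # noqa: F403

c = PackageConfig(
    built_at=datetime.now(),
    package_name="env-local-ollama-dev",
    providers={"inference": {"provider_id": "meta-reference"}},
)
c.distribution_id = "local-ollama"
c.conda_env = "env-local-ollama-dev"

# the same EnumEncoder + json round-trip build_package() uses to make
# enum (and datetime) fields YAML-safe before writing the build file
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
print(yaml.dump(to_write, sort_keys=False))
```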
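And a quick sanity check of the slimmed-down registry: `DistributionSpec.providers` now maps each `Api` to a plain provider-id string, resolved into a full `ProviderSpec` only at build time. This assumes `Api` is importable from `distribution.datatypes`, as the star imports throughout this diff suggest:

```python
from llama_toolchain.distribution.datatypes import Api
from llama_toolchain.distribution.registry import resolve_distribution_spec

dist = resolve_distribution_spec("local-ollama")
assert dist is not None
# provider ids are plain strings now; they are resolved into full
# ProviderSpecs later, inside build_package()
assert dist.providers[Api.safety] == "meta-reference"
assert resolve_distribution_spec("no-such-distribution") is None
```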