diff --git a/llama_stack/cli/stack/_show.py b/llama_stack/cli/stack/_show.py new file mode 100644 index 000000000..fb7160b7d --- /dev/null +++ b/llama_stack/cli/stack/_show.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import importlib.resources +import json +import os +import shutil +import sys +import textwrap +from pathlib import Path + +import yaml +from prompt_toolkit import prompt +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.validation import Validator +from termcolor import colored, cprint + +from llama_stack.cli.stack.utils import ImageType, available_templates_specs, generate_run_config +from llama_stack.core.build import get_provider_dependencies +from llama_stack.core.datatypes import ( + BuildConfig, + BuildProvider, + DistributionSpec, +) +from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.external import load_external_apis +from llama_stack.core.stack import replace_env_vars +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.core.utils.exec import run_command +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import Api + +TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" + +logger = get_logger(name=__name__, category="cli") + + +# These are the dependencies needed by the distribution server. +# `llama-stack` is automatically installed by the installation script. +SERVER_DEPENDENCIES = [ + "aiosqlite", + "fastapi", + "fire", + "httpx", + "uvicorn", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-http", +] + + +def run_stack_show_command(args: argparse.Namespace) -> None: + current_venv = os.environ.get("VIRTUAL_ENV") + env_name = args.env_name or current_venv + + if args.distro: + available_templates = available_templates_specs() + if args.distro not in available_templates: + cprint( + f"Could not find template {args.distro}. Please run `llama stack show --list-distros` to check out the available templates", + color="red", + file=sys.stderr, + ) + sys.exit(1) + build_config = available_templates[args.distro] + # always venv, conda is gone and container is separate. + build_config.image_type = ImageType.VENV.value + elif args.providers: + provider_list: dict[str, list[BuildProvider]] = dict() + for api_provider in args.providers.split(","): + if "=" not in api_provider: + cprint( + "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2", + color="red", + file=sys.stderr, + ) + sys.exit(1) + api, provider_type = api_provider.split("=") + providers_for_api = get_provider_registry().get(Api(api), None) + if providers_for_api is None: + cprint( + f"{api} is not a valid API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + if provider_type in providers_for_api: + provider = BuildProvider( + provider_type=provider_type, + module=None, + ) + provider_list.setdefault(api, []).append(provider) + else: + cprint( + f"{provider} is not a valid provider for the {api} API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + distribution_spec = DistributionSpec( + providers=provider_list, + description=",".join(args.providers), + ) + build_config = BuildConfig(image_type=ImageType.VENV.value, distribution_spec=distribution_spec) + elif not args.config and not args.distro: + name = prompt( + "> Enter a name for your Llama Stack (e.g. my-local-stack): ", + validator=Validator.from_callable( + lambda x: len(x) > 0, + error_message="Name cannot be empty, please enter a name", + ), + ) + + image_type = prompt( + "> Enter the image type you want your Llama Stack to be built as (use to see options): ", + completer=WordCompleter([e.value for e in ImageType]), + complete_while_typing=True, + validator=Validator.from_callable( + lambda x: x in [e.value for e in ImageType], + error_message="Invalid image type. Use to see options", + ), + ) + + env_name = f"llamastack-{name}" + + cprint( + textwrap.dedent( + """ + Llama Stack is composed of several APIs working together. Let's select + the provider types (implementations) you want to use for these APIs. + """, + ), + color="green", + file=sys.stderr, + ) + + cprint("Tip: use to see options for the providers.\n", color="green", file=sys.stderr) + + providers: dict[str, list[BuildProvider]] = dict() + for api, providers_for_api in get_provider_registry().items(): + available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")] + if not available_providers: + continue + api_provider = prompt( + f"> Enter provider for API {api.value}: ", + completer=WordCompleter(available_providers), + complete_while_typing=True, + validator=Validator.from_callable( + lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847 + error_message="Invalid provider, use to see options", + ), + ) + + string_providers = api_provider.split(" ") + + for provider in string_providers: + providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider)) + + description = prompt( + "\n > (Optional) Enter a short description for your Llama Stack: ", + default="", + ) + + distribution_spec = DistributionSpec( + providers=providers, + description=description, + ) + + build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec) + else: + with open(args.config) as f: + try: + contents = yaml.safe_load(f) + contents = replace_env_vars(contents) + build_config = BuildConfig(**contents) + build_config.image_type = "venv" + except Exception as e: + cprint( + f"Could not parse config file {args.config}: {e}", + color="red", + file=sys.stderr, + ) + sys.exit(1) + + print(f"# Dependencies for {args.distro or args.config or env_name}") + + normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config) + normal_deps += SERVER_DEPENDENCIES + + # Quote deps with commas + quoted_normal_deps = [quote_if_needed(dep) for dep in normal_deps] + print(f"uv pip install {' '.join(quoted_normal_deps)}") + + for special_dep in special_deps: + print(f"uv pip install {quote_if_needed(special_dep)}") + + for external_dep in external_provider_dependencies: + print(f"uv pip install {quote_if_needed(external_dep)}") + + +def quote_if_needed(dep): + # Add quotes if the dependency contains a comma (likely version specifier) + return f"'{dep}'" if "," in dep else dep diff --git a/llama_stack/cli/stack/show.py b/llama_stack/cli/stack/show.py new file mode 100644 index 000000000..9c8c2c90e --- /dev/null +++ b/llama_stack/cli/stack/show.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import argparse +import textwrap + +from llama_stack.cli.stack.utils import ImageType +from llama_stack.cli.subcommand import Subcommand + + +class StackShow(Subcommand): + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "show", + prog="llama stack show", + description="show the dependencies for a llama stack distribution", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_stack_show_command) + + def _add_arguments(self): + self.parser.add_argument( + "--config", + type=str, + default=None, + help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively", + ) + + self.parser.add_argument( + "--distro", + type=str, + default=None, + help="Name of the distro config to use for show. You may use `llama stack show --list-distros` to check out the available distros", + ) + + self.parser.add_argument( + "--list-distros", + action="store_true", + default=False, + help="Show the available templates for building a Llama Stack distribution", + ) + + self.parser.add_argument( + "--env-name", + type=str, + help=textwrap.dedent( + f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for +the build. If not specified, currently active environment will be used if found. + """ + ), + default=None, + ) + self.parser.add_argument( + "--print-deps-only", + default=False, + action="store_true", + help="Print the dependencies for the stack only, without building the stack", + ) + self.parser.add_argument( + "--providers", + type=str, + default=None, + help="sync dependencies for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.", + ) + + def _run_stack_show_command(self, args: argparse.Namespace) -> None: + # always keep implementation completely silo-ed away from CLI so CLI + # can be fast to load and reduces dependencies + from ._show import run_stack_show_command + + return run_stack_show_command(args) diff --git a/llama_stack/cli/stack/stack.py b/llama_stack/cli/stack/stack.py index 3aff78e23..85365989c 100644 --- a/llama_stack/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -11,11 +11,11 @@ from llama_stack.cli.stack.list_stacks import StackListBuilds from llama_stack.cli.stack.utils import print_subcommand_description from llama_stack.cli.subcommand import Subcommand -from .build import StackBuild from .list_apis import StackListApis from .list_providers import StackListProviders from .remove import StackRemove from .run import StackRun +from .show import StackShow class StackParser(Subcommand): @@ -39,7 +39,7 @@ class StackParser(Subcommand): subparsers = self.parser.add_subparsers(title="stack_subcommands") # Add sub-commands - StackBuild.create(subparsers) + StackShow.create(subparsers) StackListApis.create(subparsers) StackListProviders.create(subparsers) StackRun.create(subparsers) diff --git a/llama_stack/cli/stack/utils.py b/llama_stack/cli/stack/utils.py index fdf9e1761..4d4c1b538 100644 --- a/llama_stack/cli/stack/utils.py +++ b/llama_stack/cli/stack/utils.py @@ -4,7 +4,28 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json +import sys from enum import Enum +from functools import lru_cache +from pathlib import Path + +import yaml +from termcolor import cprint + +from llama_stack.core.datatypes import ( + BuildConfig, + Provider, + StackRunConfig, +) +from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.resolver import InvalidProviderError +from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR +from llama_stack.core.utils.dynamic import instantiate_class_type +from llama_stack.core.utils.image_types import LlamaStackImageType +from llama_stack.providers.datatypes import Api + +TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions" class ImageType(Enum): @@ -19,3 +40,91 @@ def print_subcommand_description(parser, subparsers): description = subcommand.description description_text += f" {name:<21} {description}\n" parser.epilog = description_text + + +def generate_run_config( + build_config: BuildConfig, + build_dir: Path, + image_name: str, +) -> Path: + """ + Generate a run.yaml template file for user to edit from a build.yaml file + """ + apis = list(build_config.distribution_spec.providers.keys()) + run_config = StackRunConfig( + container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None), + image_name=image_name, + apis=apis, + providers={}, + external_providers_dir=build_config.external_providers_dir + if build_config.external_providers_dir + else EXTERNAL_PROVIDERS_DIR, + ) + # build providers dict + provider_registry = get_provider_registry(build_config) + for api in apis: + run_config.providers[api] = [] + providers = build_config.distribution_spec.providers[api] + + for provider in providers: + pid = provider.provider_type.split("::")[-1] + + p = provider_registry[Api(api)][provider.provider_type] + if p.deprecation_error: + raise InvalidProviderError(p.deprecation_error) + + try: + config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class) + except (ModuleNotFoundError, ValueError) as exc: + # HACK ALERT: + # This code executes after building is done, the import cannot work since the + # package is either available in the venv or container - not available on the host. + # TODO: use a "is_external" flag in ProviderSpec to check if the provider is + # external + cprint( + f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}", + color="yellow", + file=sys.stderr, + ) + # Set config_type to None to avoid UnboundLocalError + config_type = None + + if config_type is not None and hasattr(config_type, "sample_run_config"): + config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}") + else: + config = {} + + p_spec = Provider( + provider_id=pid, + provider_type=provider.provider_type, + config=config, + module=provider.module, + ) + run_config.providers[api].append(p_spec) + + run_config_file = build_dir / f"{image_name}-run.yaml" + + with open(run_config_file, "w") as f: + to_write = json.loads(run_config.model_dump_json()) + f.write(yaml.dump(to_write, sort_keys=False)) + + # Only print this message for non-container builds since it will be displayed before the + # container is built + # For non-container builds, the run.yaml is generated at the very end of the build process so it + # makes sense to display this message + if build_config.image_type != LlamaStackImageType.CONTAINER.value: + cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr) + return run_config_file + + +@lru_cache +def available_templates_specs() -> dict[str, BuildConfig]: + import yaml + + template_specs = {} + for p in TEMPLATES_PATH.rglob("*build.yaml"): + template_name = p.parent.name + with open(p) as f: + build_config = BuildConfig(**yaml.safe_load(f)) + template_specs[template_name] = build_config + return template_specs diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index 4b20588fd..5586b1dd8 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -7,6 +7,8 @@ import importlib.resources import logging import sys +import tomllib +from pathlib import Path from pydantic import BaseModel from termcolor import cprint @@ -72,8 +74,13 @@ def get_provider_dependencies( external_provider_deps.append(provider_spec.module) else: external_provider_deps.extend(provider_spec.module) - if hasattr(provider_spec, "pip_packages"): - deps.extend(provider_spec.pip_packages) + + pyproject = Path(provider_spec.module.replace(".", "/")) / "pyproject.toml" + with open(pyproject, "rb") as f: + data = tomllib.load(f) + + dependencies = data.get("project", {}).get("dependencies", []) + deps.extend(dependencies) if hasattr(provider_spec, "container_image") and provider_spec.container_image: raise ValueError("A stack's dependencies cannot have a container image") diff --git a/tests/unit/distribution/test_build_path.py b/tests/unit/distribution/test_build_path.py index 52a71286b..b4094618e 100644 --- a/tests/unit/distribution/test_build_path.py +++ b/tests/unit/distribution/test_build_path.py @@ -6,7 +6,7 @@ from pathlib import Path -from llama_stack.cli.stack._build import ( +from llama_stack.cli.stack._sync import ( _run_stack_build_command_from_build_config, ) from llama_stack.core.datatypes import BuildConfig, DistributionSpec