bunch more work to make adapters work

This commit is contained in:
Ashwin Bharambe 2024-08-27 19:15:42 -07:00
parent 68f3db62e9
commit c4fe72c3a3
20 changed files with 461 additions and 173 deletions

View file

@ -9,6 +9,7 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import ApiBuild
from .configure import ApiConfigure
class ApiParser(Subcommand):
@ -24,3 +25,4 @@ class ApiParser(Subcommand):
# Add sub-commands
ApiBuild.create(subparsers)
ApiConfigure.create(subparsers)

View file

@ -6,22 +6,92 @@
import argparse
import os
import random
import string
import uuid
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Tuple
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import DISTRIBS_BASE_DIR
from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
def random_string():
return "".join(random.choices(string.ascii_letters + string.digits, k=8))
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
if isinstance(provider, InlineProviderSpec):
return provider.pip_packages, provider.docker_image
else:
if provider.adapter:
return provider.adapter.pip_packages, None
return [], None
pip_packages, docker_image = _deps(provider)
for dep in dependencies.values():
dep_pip_packages, dep_docker_image = _deps(dep)
if docker_image and dep_docker_image:
raise ValueError(
"You can only have the root provider specify a docker image"
)
pip_packages.extend(dep_pip_packages)
return Dependencies(docker_image=docker_image, pip_packages=pip_packages)
def parse_dependencies(
dependencies: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
from llama_toolchain.distribution.distribution import api_providers
all_providers = api_providers()
deps = {}
for dep in dependencies.split(","):
dep = dep.strip()
if not dep:
continue
api_str, provider = dep.split("=")
api = Api(api_str)
provider = provider.strip()
if provider not in all_providers[api]:
parser.error(f"Provider `{provider}` is not available for API `{api}`")
return
deps[api] = all_providers[api][provider]
return deps
class ApiBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"install",
"build",
prog="llama api build",
description="Build a Llama stack API provider container",
formatter_class=argparse.RawTextHelpFormatter,
@ -36,7 +106,7 @@ class ApiBuild(Subcommand):
self.parser.add_argument(
"api",
choices=allowed_args,
help="Stack API (one of: {})".format(", ".join(allowed_args))
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
@ -45,73 +115,104 @@ class ApiBuild(Subcommand):
help="The provider to package into the container",
required=True,
)
self.parser.add_argument(
"--container-name",
type=str,
help="Name of the container (including tag if needed)",
required=True,
)
self.parser.add_argument(
"--dependencies",
type=str,
help="Comma separated list of (downstream_api=provider) dependencies needed for the API",
required=False,
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env). Defaults to a random UUID",
required=False,
)
self.parser.add_argument(
"--type",
type=str,
default="container",
choices=[v.value for v in BuildType],
)
def _run_api_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.datatypes import DistributionConfig
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import resolve_distribution_spec
from llama_toolchain.distribution.distribution import api_providers
os.makedirs(DISTRIBS_BASE_DIR, exist_ok=True)
script = pkg_resources.resource_filename(
"llama_toolchain",
"distribution/build_api.sh",
)
os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
all_providers = api_providers()
dist = resolve_distribution_spec(args.spec)
if dist is None:
self.parser.error(f"Could not find distribution {args.spec}")
api = Api(args.api)
assert api in all_providers
providers = all_providers[api]
if args.provider not in providers:
self.parser.error(
f"Provider `{args.provider}` is not available for API `{api}`"
)
return
distrib_dir = DISTRIBS_BASE_DIR / args.name
os.makedirs(distrib_dir, exist_ok=True)
deps = distribution_dependencies(dist)
if not args.conda_env:
print(f"Using {args.name} as the Conda environment for this distribution")
conda_env = args.conda_env or args.name
config_file = distrib_dir / "config.yaml"
if config_file.exists():
c = DistributionConfig(**yaml.safe_load(config_file.read_text()))
if c.spec != dist.spec_id:
self.parser.error(
f"already installed distribution with `spec={c.spec}` does not match provided spec `{args.spec}`"
)
return
if c.conda_env != conda_env:
self.parser.error(
f"already installed distribution has `conda_env={c.conda_env}` different from provided conda env `{conda_env}`"
)
return
name = args.name or random_string()
if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{name}"
else:
with open(config_file, "w") as f:
c = DistributionConfig(
spec=dist.spec_id,
name=args.name,
conda_env=conda_env,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
package_name = f"env-{args.provider}-{name}"
package_name = package_name.replace("::", "-")
return_code = run_with_pty([script, conda_env, args.name, " ".join(deps)])
build_dir = BUILDS_BASE_DIR / args.api
os.makedirs(build_dir, exist_ok=True)
provider_deps = parse_dependencies(args.dependencies or "", self.parser)
dependencies = get_dependencies(providers[args.provider], provider_deps)
package_file = build_dir / f"{package_name}.yaml"
stub_config = {
api.value: {
"provider_id": args.provider,
},
**{k: {"provider_id": v} for k, v in provider_deps.items()},
}
with open(package_file, "w") as f:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
docker_image=(
package_name if args.type == BuildType.container.value else None
),
conda_env=(
package_name if args.type == BuildType.conda_env.value else None
),
providers=stub_config,
)
f.write(yaml.dump(c.dict(), sort_keys=False))
if args.type == BuildType.container.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_container.sh"
)
args = [
script,
args.api,
package_name,
dependencies.docker_image or "python:3.10-slim",
" ".join(dependencies.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_conda_env.sh"
)
args = [
script,
args.api,
package_name,
" ".join(dependencies.pip_packages),
]
return_code = run_with_pty(args)
assert return_code == 0, cprint(
f"Failed to install distribution {dist.spec_id}", color="red"
f"Failed to build target {package_name}", color="red"
)
cprint(
f"Distribution `{args.name}` (with spec {dist.spec_id}) has been installed successfully!",
f"Target `{target_name}` built with configuration at {str(package_file)}",
color="green",
)

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from pathlib import Path
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
from termcolor import cprint
class ApiConfigure(Subcommand):
"""Llama cli for configuring llama toolchain configs"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"configure",
prog="llama api configure",
description="configure a llama stack API provider",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_api_configure_cmd)
def _add_arguments(self):
from llama_toolchain.distribution.distribution import stack_apis
allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument(
"api",
choices=allowed_args,
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the provider build to fully configure",
required=True,
)
def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
if not config_file.exists():
self.parser.error(
f"Could not find {config_file}. Please run `llama api build` first"
)
return
configure_llama_provider(config_file)
def configure_llama_provider(config_file: Path) -> None:
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.distribution.dynamic import instantiate_class_type
with open(config_file, "r") as f:
config = PackageConfig(**yaml.safe_load(f))
all_providers = api_providers()
provider_configs = {}
for api, stub_config in config.providers.items():
providers = all_providers[Api(api)]
provider_id = stub_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_spec = providers[provider_id]
cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
config_type = instantiate_class_type(provider_spec.config_class)
print(f"Config type: {config_type}")
provider_config = prompt_for_config(
config_type,
)
print("")
provider_configs[api.value] = {
"provider_id": provider_id,
**provider_config.dict(),
}
config.providers = provider_configs
with open(config_file, "w") as fp:
to_write = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
fp.write(yaml.dump(to_write, sort_keys=False))
print(f"YAML configuration has been written to {config_path}")