All the new CLI for api + stack work

This commit is contained in:
Ashwin Bharambe 2024-08-28 15:52:49 -07:00
parent fd3b65b718
commit 197f768636
16 changed files with 459 additions and 486 deletions

View file

@ -5,52 +5,12 @@
# the root directory of this source tree.
import argparse
import json
import os
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
import pkg_resources
import yaml
from termcolor import cprint
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.distribution.datatypes import * # noqa: F403
class BuildType(Enum):
container = "container"
conda_env = "conda_env"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
def get_dependencies(
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
) -> Dependencies:
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
pip_packages = provider.pip_packages
for dep in dependencies.values():
if dep.docker_image:
raise ValueError(
"You can only have the root provider specify a docker image"
)
pip_packages.extend(dep.pip_packages)
return Dependencies(
docker_image=provider.docker_image,
pip_packages=pip_packages + SERVER_DEPENDENCIES,
)
def parse_dependencies(
dependencies: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
@ -89,6 +49,9 @@ class ApiBuild(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.distribution import stack_apis
from llama_toolchain.distribution.package import (
BuildType,
)
allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument(
@ -123,101 +86,20 @@ class ApiBuild(Subcommand):
)
def _run_api_build_command(self, args: argparse.Namespace) -> None:
from llama_toolchain.common.exec import run_with_pty
from llama_toolchain.distribution.distribution import api_providers
from llama_toolchain.common.serialize import EnumEncoder
os.makedirs(BUILDS_BASE_DIR, exist_ok=True)
all_providers = api_providers()
api = Api(args.api)
assert api in all_providers
providers = all_providers[api]
if args.provider not in providers:
self.parser.error(
f"Provider `{args.provider}` is not available for API `{api}`"
)
return
if args.type == BuildType.container.value:
package_name = f"image-{args.provider}-{args.name}"
else:
package_name = f"env-{args.provider}-{args.name}"
package_name = package_name.replace("::", "-")
build_dir = BUILDS_BASE_DIR / args.api
os.makedirs(build_dir, exist_ok=True)
# get these names straight. too confusing.
provider_deps = parse_dependencies(args.dependencies or "", self.parser)
dependencies = get_dependencies(providers[args.provider], provider_deps)
package_file = build_dir / f"{package_name}.yaml"
stub_config = {
api.value: {
"provider_id": args.provider,
},
**provider_deps,
}
# properly handle the case where package exists but has
# inconsistent configuration for the providers. if possible,
# we don't want to overwrite the existing configuration.
if package_file.exists():
cprint(
f"Build `{package_name}` exists; will reconfigure",
color="yellow",
)
c = PackageConfig(**yaml.safe_load(package_file.read_text()))
else:
c = PackageConfig(
built_at=datetime.now(),
package_name=package_name,
providers=stub_config,
)
c.docker_image = (
package_name if args.type == BuildType.container.value else None
from llama_toolchain.distribution.package import (
ApiInput,
BuildType,
build_package,
)
c.conda_env = package_name if args.type == BuildType.conda_env.value else None
with open(package_file, "w") as f:
to_write = json.loads(json.dumps(c.dict(), cls=EnumEncoder))
f.write(yaml.dump(to_write, sort_keys=False))
if args.type == BuildType.container.value:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_container.sh"
)
args = [
script,
args.api,
package_name,
dependencies.docker_image or "python:3.10-slim",
" ".join(dependencies.pip_packages),
]
else:
script = pkg_resources.resource_filename(
"llama_toolchain", "distribution/build_conda_env.sh"
)
args = [
script,
args.api,
package_name,
" ".join(dependencies.pip_packages),
]
return_code = run_with_pty(args)
if return_code != 0:
cprint(
f"Failed to build target {package_name} with return code {return_code}",
color="red",
)
return
cprint(
f"Target `{package_name}` built with configuration at {str(package_file)}",
color="green",
api_input = ApiInput(
api=Api(args.api),
provider=args.provider,
dependencies=parse_dependencies(args.dependencies or "", self.parser),
)
build_package(
[api_input],
build_type=BuildType(args.type),
name=args.name,
)