Simplify and generalize llama api build yay

This commit is contained in:
Ashwin Bharambe 2024-08-30 14:51:40 -07:00
parent 297d51b183
commit f8517e4688
9 changed files with 103 additions and 151 deletions

View file

@ -11,15 +11,15 @@ from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
def parse_dependencies(
dependencies: str, parser: argparse.ArgumentParser
def parse_api_provider_tuples(
tuples: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
from llama_toolchain.core.distribution import api_providers
all_providers = api_providers()
deps = {}
for dep in dependencies.split(","):
for dep in tuples.split(","):
dep = dep.strip()
if not dep:
continue
@ -48,29 +48,13 @@ class ApiBuild(Subcommand):
self.parser.set_defaults(func=self._run_api_build_command)
def _add_arguments(self):
from llama_toolchain.core.distribution import stack_apis
from llama_toolchain.core.package import (
BuildType,
)
allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument(
"api",
choices=allowed_args,
help="Stack API (one of: {})".format(", ".join(allowed_args)),
)
self.parser.add_argument(
"--provider",
type=str,
help="The provider to package into the container",
required=True,
)
self.parser.add_argument(
"--dependencies",
type=str,
help="Comma separated list of (downstream_api=provider) dependencies needed for the API",
required=False,
"api_providers",
help="Comma separated list of (api=provider) tuples",
)
self.parser.add_argument(
"--name",
@ -92,14 +76,23 @@ class ApiBuild(Subcommand):
build_package,
)
api_input = ApiInput(
api=Api(args.api),
provider=args.provider,
dependencies=parse_dependencies(args.dependencies or "", self.parser),
)
parsed = parse_api_provider_tuples(args.api_providers, self.parser)
api_inputs = []
for api, provider_spec in parsed.items():
for dep in provider_spec.api_dependencies:
if dep not in parsed:
self.parser.error(f"API {api} needs dependency {dep} provided also")
return
api_inputs.append(
ApiInput(
api=api,
provider=provider_spec.provider_id,
)
)
build_package(
[api_input],
api_inputs,
build_type=BuildType(args.type),
name=args.name,
)