From 9be0edc76c186430fb2fa82bc524a4fc7094c7dd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 2 Sep 2024 18:37:31 -0700 Subject: [PATCH] Allow building an "adhoc" distribution --- llama_toolchain/cli/stack/build.py | 81 ++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 14 deletions(-) diff --git a/llama_toolchain/cli/stack/build.py b/llama_toolchain/cli/stack/build.py index fd7511bda..c81a6d350 100644 --- a/llama_toolchain/cli/stack/build.py +++ b/llama_toolchain/cli/stack/build.py @@ -10,6 +10,30 @@ from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.core.datatypes import * # noqa: F403 +def parse_api_provider_tuples( + tuples: str, parser: argparse.ArgumentParser +) -> Dict[str, ProviderSpec]: + from llama_toolchain.core.distribution import api_providers + + all_providers = api_providers() + + deps = {} + for dep in tuples.split(","): + dep = dep.strip() + if not dep: + continue + api_str, provider = dep.split("=") + api = Api(api_str) + + provider = provider.strip() + if provider not in all_providers[api]: + parser.error(f"Provider `{provider}` is not available for API `{api}`") + return + deps[api] = all_providers[api][provider] + + return deps + + class StackBuild(Subcommand): def __init__(self, subparsers: argparse._SubParsersAction): super().__init__() @@ -32,8 +56,12 @@ class StackBuild(Subcommand): self.parser.add_argument( "distribution", type=str, - choices=allowed_ids, - help="Distribution to build (one of: {})".format(allowed_ids), + help="Distribution to build (either \"adhoc\" OR one of: {})".format(allowed_ids), + ) + self.parser.add_argument( + "api_providers", + nargs='?', + help="Comma separated list of (api=provider) tuples", ) self.parser.add_argument( @@ -57,24 +85,49 @@ class StackBuild(Subcommand): build_package, ) - dist = resolve_distribution_spec(args.distribution) - if dist is None: - self.parser.error(f"Could not find distribution {args.distribution}") - return - api_inputs = [] - for api, provider_id in dist.providers.items(): - api_inputs.append( - ApiInput( - api=api, - provider=provider_id, + if args.distribution == "adhoc": + if not args.api_providers: + self.parser.error("You must specify API providers with (api=provider,...) for building an adhoc distribution") + return + + parsed = parse_api_provider_tuples(args.api_providers, self.parser) + for api, provider_spec in parsed.items(): + for dep in provider_spec.api_dependencies: + if dep not in parsed: + self.parser.error(f"API {api} needs dependency {dep} provided also") + return + + api_inputs.append( + ApiInput( + api=api, + provider=provider_spec.provider_id, + ) ) - ) + docker_image = None + else: + if args.api_providers: + self.parser.error("You cannot specify API providers for pre-registered distributions") + return + + dist = resolve_distribution_spec(args.distribution) + if dist is None: + self.parser.error(f"Could not find distribution {args.distribution}") + return + + for api, provider_id in dist.providers.items(): + api_inputs.append( + ApiInput( + api=api, + provider=provider_id, + ) + ) + docker_image = dist.docker_image build_package( api_inputs, build_type=BuildType(args.type), name=args.name, distribution_id=args.distribution, - docker_image=dist.docker_image, + docker_image=docker_image, )