Allow building an "adhoc" distribution

This commit is contained in:
Ashwin Bharambe 2024-09-02 18:37:31 -07:00
parent d99c06fce8
commit 9be0edc76c

View file

@ -10,6 +10,30 @@ from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403 from llama_toolchain.core.datatypes import * # noqa: F403
def parse_api_provider_tuples(
tuples: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
from llama_toolchain.core.distribution import api_providers
all_providers = api_providers()
deps = {}
for dep in tuples.split(","):
dep = dep.strip()
if not dep:
continue
api_str, provider = dep.split("=")
api = Api(api_str)
provider = provider.strip()
if provider not in all_providers[api]:
parser.error(f"Provider `{provider}` is not available for API `{api}`")
return
deps[api] = all_providers[api][provider]
return deps
class StackBuild(Subcommand): class StackBuild(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction): def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__() super().__init__()
@ -32,8 +56,12 @@ class StackBuild(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"distribution", "distribution",
type=str, type=str,
choices=allowed_ids, help="Distribution to build (either \"adhoc\" OR one of: {})".format(allowed_ids),
help="Distribution to build (one of: {})".format(allowed_ids), )
self.parser.add_argument(
"api_providers",
nargs='?',
help="Comma separated list of (api=provider) tuples",
) )
self.parser.add_argument( self.parser.add_argument(
@ -57,24 +85,49 @@ class StackBuild(Subcommand):
build_package, build_package,
) )
dist = resolve_distribution_spec(args.distribution)
if dist is None:
self.parser.error(f"Could not find distribution {args.distribution}")
return
api_inputs = [] api_inputs = []
for api, provider_id in dist.providers.items(): if args.distribution == "adhoc":
api_inputs.append( if not args.api_providers:
ApiInput( self.parser.error("You must specify API providers with (api=provider,...) for building an adhoc distribution")
api=api, return
provider=provider_id,
parsed = parse_api_provider_tuples(args.api_providers, self.parser)
for api, provider_spec in parsed.items():
for dep in provider_spec.api_dependencies:
if dep not in parsed:
self.parser.error(f"API {api} needs dependency {dep} provided also")
return
api_inputs.append(
ApiInput(
api=api,
provider=provider_spec.provider_id,
)
) )
) docker_image = None
else:
if args.api_providers:
self.parser.error("You cannot specify API providers for pre-registered distributions")
return
dist = resolve_distribution_spec(args.distribution)
if dist is None:
self.parser.error(f"Could not find distribution {args.distribution}")
return
for api, provider_id in dist.providers.items():
api_inputs.append(
ApiInput(
api=api,
provider=provider_id,
)
)
docker_image = dist.docker_image
build_package( build_package(
api_inputs, api_inputs,
build_type=BuildType(args.type), build_type=BuildType(args.type),
name=args.name, name=args.name,
distribution_id=args.distribution, distribution_id=args.distribution,
docker_image=dist.docker_image, docker_image=docker_image,
) )