Some quick fixes to the CLI behavior to make it consistent

This commit is contained in:
Ashwin Bharambe 2024-08-28 17:17:46 -07:00
parent f1244f6d9e
commit 3063329dad
5 changed files with 74 additions and 17 deletions

View file

@ -81,7 +81,7 @@ class ApiBuild(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"--type", "--type",
type=str, type=str,
default="container", default="conda_env",
choices=[v.value for v in BuildType], choices=[v.value for v in BuildType],
) )

View file

@ -32,6 +32,7 @@ class ApiConfigure(Subcommand):
def _add_arguments(self): def _add_arguments(self):
from llama_toolchain.distribution.distribution import stack_apis from llama_toolchain.distribution.distribution import stack_apis
from llama_toolchain.distribution.package import BuildType
allowed_args = [a.name for a in stack_apis()] allowed_args = [a.name for a in stack_apis()]
self.parser.add_argument( self.parser.add_argument(
@ -42,15 +43,41 @@ class ApiConfigure(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"--build-name", "--build-name",
type=str, type=str,
help="Name of the provider build to fully configure", help="(Fully qualified) name of the API build to configure. Alternatively, specify the --provider and --name options.",
required=True, required=False,
)
self.parser.add_argument(
"--provider",
type=str,
help="The provider chosen for the API",
required=False,
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build target (image, conda env)",
required=False,
)
self.parser.add_argument(
"--type",
type=str,
default="conda_env",
choices=[v.value for v in BuildType],
) )
def _run_api_configure_cmd(self, args: argparse.Namespace) -> None: def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.build_name from llama_toolchain.distribution.package import BuildType
if not name.endswith(".yaml"):
name += ".yaml" if args.build_name:
config_file = BUILDS_BASE_DIR / args.api / name name = args.build_name
if name.endswith(".yaml"):
name = name.replace(".yaml", "")
else:
build_type = BuildType(args.type)
name = f"{build_type.descriptor()}-{args.provider}-{args.name}"
config_file = BUILDS_BASE_DIR / args.api / f"{name}.yaml"
if not config_file.exists(): if not config_file.exists():
self.parser.error( self.parser.error(
f"Could not find {config_file}. Please run `llama api build` first" f"Could not find {config_file}. Please run `llama api build` first"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import * # noqa: F403 from llama_toolchain.distribution.datatypes import * # noqa: F403
@ -46,7 +45,7 @@ class StackBuild(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"--type", "--type",
type=str, type=str,
default="container", default="conda_env",
choices=[v.value for v in BuildType], choices=[v.value for v in BuildType],
) )

View file

@ -31,19 +31,48 @@ class StackConfigure(Subcommand):
self.parser.set_defaults(func=self._run_stack_configure_cmd) self.parser.set_defaults(func=self._run_stack_configure_cmd)
def _add_arguments(self): def _add_arguments(self):
from llama_toolchain.distribution.package import BuildType
from llama_toolchain.distribution.registry import available_distribution_specs
self.parser.add_argument( self.parser.add_argument(
"--build-name", "--build-name",
type=str, type=str,
help="Name of the stack build to configure", help="(Fully qualified) name of the stack build to configure. Alternatively, provider --distribution and --name",
required=True, required=False,
)
allowed_ids = [d.distribution_id for d in available_distribution_specs()]
self.parser.add_argument(
"--distribution",
type=str,
choices=allowed_ids,
help="Distribution (one of: {})".format(allowed_ids),
required=False,
)
self.parser.add_argument(
"--name",
type=str,
help="Name of the build",
required=False,
)
self.parser.add_argument(
"--type",
type=str,
default="conda_env",
choices=[v.value for v in BuildType],
) )
def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
name = args.build_name from llama_toolchain.distribution.package import BuildType
if not name.endswith(".yaml"):
name += ".yaml"
config_file = BUILDS_BASE_DIR / "stack" / name if args.build_name:
name = args.build_name
if name.endswith(".yaml"):
name = name.replace(".yaml", "")
else:
build_type = BuildType(args.type)
name = f"{build_type.descriptor()}-{args.distribution}-{args.name}"
config_file = BUILDS_BASE_DIR / "stack" / f"{name}.yaml"
if not config_file.exists(): if not config_file.exists():
self.parser.error( self.parser.error(
f"Could not find {config_file}. Please run `llama stack build` first" f"Could not find {config_file}. Please run `llama stack build` first"

View file

@ -28,6 +28,9 @@ class BuildType(Enum):
container = "container" container = "container"
conda_env = "conda_env" conda_env = "conda_env"
def descriptor(self) -> str:
return "image" if self == self.container else "env"
class Dependencies(BaseModel): class Dependencies(BaseModel):
pip_packages: List[str] pip_packages: List[str]
@ -77,12 +80,11 @@ def build_package(
provider = distribution_id if is_stack else api1.provider provider = distribution_id if is_stack else api1.provider
api_or_stack = "stack" if is_stack else api1.api.value api_or_stack = "stack" if is_stack else api1.api.value
build_desc = "image" if build_type == BuildType.container else "env"
build_dir = BUILDS_BASE_DIR / api_or_stack build_dir = BUILDS_BASE_DIR / api_or_stack
os.makedirs(build_dir, exist_ok=True) os.makedirs(build_dir, exist_ok=True)
package_name = f"{build_desc}-{provider}-{name}" package_name = f"{build_type.descriptor()}-{provider}-{name}"
package_name = package_name.replace("::", "-") package_name = package_name.replace("::", "-")
package_file = build_dir / f"{package_name}.yaml" package_file = build_dir / f"{package_name}.yaml"