Remove llama api subcommands

This commit is contained in:
Ashwin Bharambe 2024-09-02 18:48:19 -07:00
parent 9be0edc76c
commit 5927f3c3c0
5 changed files with 0 additions and 313 deletions

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import ApiParser # noqa

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from .build import ApiBuild
from .configure import ApiConfigure
from .start import ApiStart
class ApiParser(Subcommand):
    """Parser for the `llama api` subcommand group.

    Registers the top-level `api` command and attaches each of its
    sub-commands (build / configure / start) to a nested subparser set.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "api",
            prog="llama api",
            description="Operate on llama stack API providers",
        )

        api_subparsers = self.parser.add_subparsers(title="api_subcommands")

        # Each sub-command class knows how to attach itself.
        for subcommand_cls in (ApiBuild, ApiConfigure, ApiStart):
            subcommand_cls.create(api_subparsers)

View file

@ -1,98 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from typing import Dict
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
def parse_api_provider_tuples(
    tuples: str, parser: argparse.ArgumentParser
) -> Dict[str, ProviderSpec]:
    """Parse a comma-separated list of `api=provider` tuples into provider specs.

    Args:
        tuples: user-supplied string, e.g. "inference=meta-reference,safety=x".
            Empty entries (stray commas) are skipped.
        parser: the argparse parser used to report usage errors; note that
            `parser.error()` raises SystemExit, so the trailing `return`
            statements are defensive only.

    Returns:
        Mapping from Api to the selected ProviderSpec for that API.
    """
    from llama_toolchain.core.distribution import api_providers

    all_providers = api_providers()

    deps = {}
    for dep in tuples.split(","):
        dep = dep.strip()
        if not dep:
            continue
        # Fix: a malformed entry (missing `=`, or more than one) previously
        # escaped as an unhandled ValueError from tuple unpacking; report it
        # as a CLI usage error instead.
        parts = dep.split("=")
        if len(parts) != 2:
            parser.error(f"Could not parse `{dep}`; expected api=provider")
            return
        api_str, provider = parts
        # Fix: strip the API name too, so "api = provider" spacing does not
        # break the Api(...) enum lookup.
        api = Api(api_str.strip())
        provider = provider.strip()
        if provider not in all_providers[api]:
            parser.error(f"Provider `{provider}` is not available for API `{api}`")
            return
        deps[api] = all_providers[api][provider]

    return deps
class ApiBuild(Subcommand):
    """`llama api build`: package a set of API providers into a build target."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "build",
            prog="llama api build",
            description="Build a Llama stack API provider container",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_api_build_command)

    def _add_arguments(self):
        # Import lazily so `llama --help` stays fast.
        from llama_toolchain.core.package import (
            BuildType,
        )

        self.parser.add_argument(
            "api_providers",
            help="Comma separated list of (api=provider) tuples",
        )
        self.parser.add_argument(
            "--name",
            type=str,
            required=True,
            help="Name of the build target (image, conda env)",
        )
        self.parser.add_argument(
            "--type",
            type=str,
            default="conda_env",
            choices=[bt.value for bt in BuildType],
        )

    def _run_api_build_command(self, args: argparse.Namespace) -> None:
        from llama_toolchain.core.package import (
            ApiInput,
            BuildType,
            build_package,
        )

        parsed = parse_api_provider_tuples(args.api_providers, self.parser)

        # Every API a chosen provider depends on must itself be provided.
        api_inputs = []
        for api, provider_spec in parsed.items():
            for dep in provider_spec.api_dependencies:
                if dep not in parsed:
                    self.parser.error(f"API {api} needs dependency {dep} provided also")
                    return
            api_inputs.append(
                ApiInput(
                    api=api,
                    provider=provider_spec.provider_id,
                )
            )

        build_package(
            api_inputs,
            build_type=BuildType(args.type),
            name=args.name,
        )

View file

@ -1,79 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
from pathlib import Path
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
from llama_toolchain.core.datatypes import * # noqa: F403
class ApiConfigure(Subcommand):
    """Llama cli for configuring llama toolchain configs"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "configure",
            prog="llama api configure",
            description="configure a llama stack API provider",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_api_configure_cmd)

    def _add_arguments(self):
        # Import lazily so argument registration does not slow CLI startup.
        from llama_toolchain.core.package import BuildType

        self.parser.add_argument(
            "--build-name",
            type=str,
            required=True,
            help="Name of the build",
        )
        self.parser.add_argument(
            "--build-type",
            type=str,
            default="conda_env",
            choices=[bt.value for bt in BuildType],
        )

    def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.core.package import BuildType

        # Locate the build's config under the standard adhoc builds directory.
        descriptor = BuildType(args.build_type).descriptor()
        config_file = (
            BUILDS_BASE_DIR / "adhoc" / descriptor / f"{args.build_name}.yaml"
        )
        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama api build` first"
            )
            return

        configure_llama_provider(config_file)
def configure_llama_provider(config_file: Path) -> None:
    """Interactively reconfigure the providers listed in an existing build config.

    Reads the YAML config at *config_file*, lets the user reconfigure each
    provider, and writes the result back to the same path.
    """
    from llama_toolchain.common.serialize import EnumEncoder
    from llama_toolchain.core.configure import configure_api_providers

    with open(config_file, "r") as f:
        config = PackageConfig(**yaml.safe_load(f))

    config.providers = configure_api_providers(config.providers)

    # Round-trip through JSON with EnumEncoder so enum values serialize
    # cleanly before the YAML dump.
    serializable = json.loads(json.dumps(config.dict(), cls=EnumEncoder))
    with open(config_file, "w") as fp:
        fp.write(yaml.dump(serializable, sort_keys=False))

    print(f"YAML configuration has been written to {config_file}")

View file

@ -1,99 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
import pkg_resources
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.core.datatypes import * # noqa: F403
from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
class ApiStart(Subcommand):
    """`llama api start`: launch the server for a built & configured provider."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "start",
            prog="llama api start",
            description="""start the server for a Llama API provider. You should have already built and configured the provider.""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_api_start_cmd)

    def _add_arguments(self):
        # Import lazily to keep CLI startup fast.
        from llama_toolchain.core.package import BuildType

        self.parser.add_argument(
            "build_name",
            type=str,
            help="Name of the API build you want to start",
        )
        self.parser.add_argument(
            "--build-type",
            type=str,
            default="conda_env",
            choices=[bt.value for bt in BuildType],
        )
        self.parser.add_argument(
            "--port",
            type=int,
            default=5000,
            help="Port to run the server on. Defaults to 5000",
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            default=False,
            help="Disable IPv6 support",
        )

    def _run_api_start_cmd(self, args: argparse.Namespace) -> None:
        from llama_toolchain.common.exec import run_with_pty
        from llama_toolchain.core.package import BuildType

        # `build_name` may be either an explicit YAML path or a build name
        # to resolve under the standard adhoc builds directory.
        if args.build_name.endswith(".yaml"):
            config_file = Path(args.build_name)
        else:
            descriptor = BuildType(args.build_type).descriptor()
            config_file = Path(
                BUILDS_BASE_DIR / "adhoc" / descriptor / f"{args.build_name}.yaml"
            )

        if not config_file.exists():
            self.parser.error(
                f"Could not find {config_file}. Please run `llama api build` first"
            )
            return

        with open(config_file, "r") as f:
            config = PackageConfig(**yaml.safe_load(f))

        # Docker-based builds launch via the container script; otherwise the
        # conda-env activation script is used.
        if config.docker_image:
            script = pkg_resources.resource_filename(
                "llama_toolchain",
                "core/start_container.sh",
            )
            run_args = [script, config.docker_image]
        else:
            script = pkg_resources.resource_filename(
                "llama_toolchain",
                "core/start_conda_env.sh",
            )
            run_args = [
                script,
                config.conda_env,
            ]

        run_args.extend([str(config_file), str(args.port)])
        if args.disable_ipv6:
            run_args.append("--disable-ipv6")

        run_with_pty(run_args)