Distribution server now functioning

This commit is contained in:
Ashwin Bharambe 2024-08-02 13:37:40 -07:00
parent 041cafbee3
commit 2cf9915806
21 changed files with 635 additions and 266 deletions

View file

@ -18,11 +18,13 @@ from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.distribution.registry import (
available_distributions,
resolve_distribution,
)
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command
DISTRIBS = available_distributions()
from .utils import run_command
class DistributionConfigure(Subcommand):
@ -43,18 +45,13 @@ class DistributionConfigure(Subcommand):
self.parser.add_argument(
"--name",
type=str,
help="Mame of the distribution to configure",
help="Name of the distribution to configure",
default="local-source",
choices=[d.name for d in available_distributions()],
)
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
dist = None
for d in DISTRIBS:
if d.name == args.name:
dist = d
break
dist = resolve_distribution(args.name)
if dist is None:
self.parser.error(f"Could not find distribution {args.name}")
return

View file

@ -7,6 +7,7 @@
import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import resolve_distribution
class DistributionCreate(Subcommand):
@ -23,7 +24,20 @@ class DistributionCreate(Subcommand):
self.parser.set_defaults(func=self._run_distribution_create_cmd)
def _add_arguments(self):
pass
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to create",
required=True,
)
# for each ApiSurface the user wants to support, we should
# get the list of available adapters, ask which one the user
# wants to pick and then ask for their configuration.
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
dist = resolve_distribution(args.name)
if dist is not None:
self.parser.error(f"Distribution with name {args.name} already exists")
return
raise NotImplementedError()

View file

@ -12,6 +12,7 @@ from .configure import DistributionConfigure
from .create import DistributionCreate
from .install import DistributionInstall
from .list import DistributionList
from .start import DistributionStart
class DistributionParser(Subcommand):
@ -31,3 +32,4 @@ class DistributionParser(Subcommand):
DistributionInstall.create(subparsers)
DistributionCreate.create(subparsers)
DistributionConfigure.create(subparsers)
DistributionStart.create(subparsers)

View file

@ -11,9 +11,13 @@ import shlex
import pkg_resources
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import (
available_distributions,
resolve_distribution,
)
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command, run_with_pty
DISTRIBS = available_distributions()
@ -55,12 +59,7 @@ class DistributionInstall(Subcommand):
"distribution/install_distribution.sh",
)
dist = None
for d in DISTRIBS:
if d.name == args.name:
dist = d
break
dist = resolve_distribution(args.name)
if dist is None:
self.parser.error(f"Could not find distribution {args.name}")
return

View file

@ -9,7 +9,7 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shlex
from pathlib import Path
import yaml
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.server import main as distribution_server_init
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command
class DistributionStart(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"start",
prog="llama distribution start",
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_distribution_start_cmd)
def _add_arguments(self):
self.parser.add_argument(
"--name",
type=str,
help="Name of the distribution to start",
required=True,
)
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. Defaults to 5000",
default=5000,
)
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
dist = resolve_distribution(args.name)
if dist is None:
self.parser.error(f"Distribution with name {args.name} not found")
return
config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
if not config_yaml.exists():
raise ValueError(
f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
)
with open(config_yaml, "r") as fp:
config = yaml.safe_load(fp)
conda_env = config["conda_env"]
python_exe = run_command(shlex.split("which python"))
# simple check, unfortunate
if conda_env not in python_exe:
raise ValueError(
f"Please re-run start after activating the `{conda_env}` conda environment first"
)
distribution_server_init(
dist.name,
config_yaml,
args.port,
disable_ipv6=args.disable_ipv6,
)
# run_with_pty(
# shlex.split(
# f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000"
# )
# )

View file

@ -8,21 +8,35 @@ import errno
import os
import pty
import select
import signal
import subprocess
import sys
import termios
import tty
from termcolor import cprint
def run_with_pty(command):
old_settings = termios.tcgetattr(sys.stdin)
# Create a new pseudo-terminal
master, slave = pty.openpty()
old_settings = termios.tcgetattr(sys.stdin)
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
ctrl_c_pressed = True
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
try:
# ensure the terminal does not echo input
tty.setraw(sys.stdin.fileno())
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin)
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
command,
@ -30,26 +44,36 @@ def run_with_pty(command):
stdout=slave,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
)
# Close the slave file descriptor as it's now owned by the subprocess
os.close(slave)
def handle_io():
while True:
rlist, _, _ = select.select([sys.stdin, master], [], [])
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data: # EOF
break
os.write(master, data)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
os.write(sys.stdout.fileno(), data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
@ -58,11 +82,14 @@ def run_with_pty(command):
if e.errno != errno.EIO:
raise
finally:
# Restore original terminal settings
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
process.wait()
os.close(master)
os.close(master)
if process.poll() is None:
process.terminate()
process.wait()
return process.returncode

View file

@ -8,7 +8,8 @@ import argparse
from .distribution import DistributionParser
from .download import Download
from .inference import InferenceParser
# from .inference import InferenceParser
from .model import ModelParser
@ -29,7 +30,7 @@ class LlamaCLIParser:
# Add sub-commands
Download.create(subparsers)
InferenceParser.create(subparsers)
# InferenceParser.create(subparsers)
ModelParser.create(subparsers)
DistributionParser.create(subparsers)