mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Distribution server now functioning
This commit is contained in:
parent
041cafbee3
commit
2cf9915806
21 changed files with 635 additions and 266 deletions
|
@ -18,11 +18,13 @@ from termcolor import cprint
|
|||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
|
||||
from llama_toolchain.distribution.registry import available_distributions
|
||||
from llama_toolchain.distribution.registry import (
|
||||
available_distributions,
|
||||
resolve_distribution,
|
||||
)
|
||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
||||
from .utils import run_command
|
||||
|
||||
DISTRIBS = available_distributions()
|
||||
from .utils import run_command
|
||||
|
||||
|
||||
class DistributionConfigure(Subcommand):
|
||||
|
@ -43,18 +45,13 @@ class DistributionConfigure(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
help="Mame of the distribution to configure",
|
||||
help="Name of the distribution to configure",
|
||||
default="local-source",
|
||||
choices=[d.name for d in available_distributions()],
|
||||
)
|
||||
|
||||
def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
|
||||
dist = None
|
||||
for d in DISTRIBS:
|
||||
if d.name == args.name:
|
||||
dist = d
|
||||
break
|
||||
|
||||
dist = resolve_distribution(args.name)
|
||||
if dist is None:
|
||||
self.parser.error(f"Could not find distribution {args.name}")
|
||||
return
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import argparse
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.distribution.registry import resolve_distribution
|
||||
|
||||
|
||||
class DistributionCreate(Subcommand):
|
||||
|
@ -23,7 +24,20 @@ class DistributionCreate(Subcommand):
|
|||
self.parser.set_defaults(func=self._run_distribution_create_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
pass
|
||||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
help="Name of the distribution to create",
|
||||
required=True,
|
||||
)
|
||||
# for each ApiSurface the user wants to support, we should
|
||||
# get the list of available adapters, ask which one the user
|
||||
# wants to pick and then ask for their configuration.
|
||||
|
||||
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
||||
dist = resolve_distribution(args.name)
|
||||
if dist is not None:
|
||||
self.parser.error(f"Distribution with name {args.name} already exists")
|
||||
return
|
||||
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -12,6 +12,7 @@ from .configure import DistributionConfigure
|
|||
from .create import DistributionCreate
|
||||
from .install import DistributionInstall
|
||||
from .list import DistributionList
|
||||
from .start import DistributionStart
|
||||
|
||||
|
||||
class DistributionParser(Subcommand):
|
||||
|
@ -31,3 +32,4 @@ class DistributionParser(Subcommand):
|
|||
DistributionInstall.create(subparsers)
|
||||
DistributionCreate.create(subparsers)
|
||||
DistributionConfigure.create(subparsers)
|
||||
DistributionStart.create(subparsers)
|
||||
|
|
|
@ -11,9 +11,13 @@ import shlex
|
|||
import pkg_resources
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.distribution.datatypes import distribution_dependencies
|
||||
from llama_toolchain.distribution.registry import available_distributions
|
||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||
from llama_toolchain.distribution.registry import (
|
||||
available_distributions,
|
||||
resolve_distribution,
|
||||
)
|
||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
||||
|
||||
from .utils import run_command, run_with_pty
|
||||
|
||||
DISTRIBS = available_distributions()
|
||||
|
@ -55,12 +59,7 @@ class DistributionInstall(Subcommand):
|
|||
"distribution/install_distribution.sh",
|
||||
)
|
||||
|
||||
dist = None
|
||||
for d in DISTRIBS:
|
||||
if d.name == args.name:
|
||||
dist = d
|
||||
break
|
||||
|
||||
dist = resolve_distribution(args.name)
|
||||
if dist is None:
|
||||
self.parser.error(f"Could not find distribution {args.name}")
|
||||
return
|
||||
|
|
|
@ -9,7 +9,7 @@ import argparse
|
|||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.cli.table import print_table
|
||||
|
||||
from llama_toolchain.distribution.datatypes import distribution_dependencies
|
||||
from llama_toolchain.distribution.distribution import distribution_dependencies
|
||||
from llama_toolchain.distribution.registry import available_distributions
|
||||
|
||||
|
||||
|
|
87
llama_toolchain/cli/distribution/start.py
Normal file
87
llama_toolchain/cli/distribution/start.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.distribution.registry import resolve_distribution
|
||||
from llama_toolchain.distribution.server import main as distribution_server_init
|
||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR
|
||||
|
||||
from .utils import run_command
|
||||
|
||||
|
||||
class DistributionStart(Subcommand):
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"start",
|
||||
prog="llama distribution start",
|
||||
description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
self._add_arguments()
|
||||
self.parser.set_defaults(func=self._run_distribution_start_cmd)
|
||||
|
||||
def _add_arguments(self):
|
||||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
help="Name of the distribution to start",
|
||||
required=True,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port to run the server on. Defaults to 5000",
|
||||
default=5000,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--disable-ipv6",
|
||||
action="store_true",
|
||||
help="Disable IPv6 support",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
|
||||
dist = resolve_distribution(args.name)
|
||||
if dist is None:
|
||||
self.parser.error(f"Distribution with name {args.name} not found")
|
||||
return
|
||||
|
||||
config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
|
||||
if not config_yaml.exists():
|
||||
raise ValueError(
|
||||
f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
|
||||
)
|
||||
|
||||
with open(config_yaml, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
conda_env = config["conda_env"]
|
||||
python_exe = run_command(shlex.split("which python"))
|
||||
# simple check, unfortunate
|
||||
if conda_env not in python_exe:
|
||||
raise ValueError(
|
||||
f"Please re-run start after activating the `{conda_env}` conda environment first"
|
||||
)
|
||||
|
||||
distribution_server_init(
|
||||
dist.name,
|
||||
config_yaml,
|
||||
args.port,
|
||||
disable_ipv6=args.disable_ipv6,
|
||||
)
|
||||
# run_with_pty(
|
||||
# shlex.split(
|
||||
# f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000"
|
||||
# )
|
||||
# )
|
|
@ -8,21 +8,35 @@ import errno
|
|||
import os
|
||||
import pty
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import termios
|
||||
import tty
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
def run_with_pty(command):
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
|
||||
# Create a new pseudo-terminal
|
||||
master, slave = pty.openpty()
|
||||
|
||||
old_settings = termios.tcgetattr(sys.stdin)
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
|
||||
ctrl_c_pressed = False
|
||||
|
||||
def sigint_handler(signum, frame):
|
||||
nonlocal ctrl_c_pressed
|
||||
ctrl_c_pressed = True
|
||||
cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
|
||||
|
||||
try:
|
||||
# ensure the terminal does not echo input
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
# Set up the signal handler
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
new_settings = termios.tcgetattr(sys.stdin)
|
||||
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
|
||||
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
|
||||
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
|
@ -30,26 +44,36 @@ def run_with_pty(command):
|
|||
stdout=slave,
|
||||
stderr=slave,
|
||||
universal_newlines=True,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
|
||||
# Close the slave file descriptor as it's now owned by the subprocess
|
||||
os.close(slave)
|
||||
|
||||
def handle_io():
|
||||
while True:
|
||||
rlist, _, _ = select.select([sys.stdin, master], [], [])
|
||||
while not ctrl_c_pressed:
|
||||
try:
|
||||
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
|
||||
|
||||
if sys.stdin in rlist:
|
||||
data = os.read(sys.stdin.fileno(), 1024)
|
||||
if not data: # EOF
|
||||
break
|
||||
os.write(master, data)
|
||||
if sys.stdin in rlist:
|
||||
data = os.read(sys.stdin.fileno(), 1024)
|
||||
if not data:
|
||||
break
|
||||
os.write(master, data)
|
||||
|
||||
if master in rlist:
|
||||
data = os.read(master, 1024)
|
||||
if not data:
|
||||
break
|
||||
os.write(sys.stdout.fileno(), data)
|
||||
if master in rlist:
|
||||
data = os.read(master, 1024)
|
||||
if not data:
|
||||
break
|
||||
sys.stdout.buffer.write(data)
|
||||
sys.stdout.flush()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# This will be raised when Ctrl+C is pressed
|
||||
break
|
||||
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
handle_io()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
|
@ -58,11 +82,14 @@ def run_with_pty(command):
|
|||
if e.errno != errno.EIO:
|
||||
raise
|
||||
finally:
|
||||
# Restore original terminal settings
|
||||
# Clean up
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
process.wait()
|
||||
os.close(master)
|
||||
os.close(master)
|
||||
if process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
|
||||
return process.returncode
|
||||
|
||||
|
|
|
@ -8,7 +8,8 @@ import argparse
|
|||
|
||||
from .distribution import DistributionParser
|
||||
from .download import Download
|
||||
from .inference import InferenceParser
|
||||
|
||||
# from .inference import InferenceParser
|
||||
from .model import ModelParser
|
||||
|
||||
|
||||
|
@ -29,7 +30,7 @@ class LlamaCLIParser:
|
|||
|
||||
# Add sub-commands
|
||||
Download.create(subparsers)
|
||||
InferenceParser.create(subparsers)
|
||||
# InferenceParser.create(subparsers)
|
||||
ModelParser.create(subparsers)
|
||||
DistributionParser.create(subparsers)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue