mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

Distribution server now functioning

This commit is contained in:
parent 041cafbee3
commit 2cf9915806

21 changed files with 635 additions and 266 deletions

@@ -18,11 +18,13 @@ from termcolor import cprint
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
-from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.distribution.registry import (
+    available_distributions,
+    resolve_distribution,
+)
 from llama_toolchain.utils import DISTRIBS_BASE_DIR

 from .utils import run_command

-DISTRIBS = available_distributions()


 class DistributionConfigure(Subcommand):
@@ -43,18 +45,13 @@ class DistributionConfigure(Subcommand):
         self.parser.add_argument(
             "--name",
             type=str,
-            help="Mame of the distribution to configure",
+            help="Name of the distribution to configure",
             default="local-source",
             choices=[d.name for d in available_distributions()],
         )

     def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
-        dist = None
-        for d in DISTRIBS:
-            if d.name == args.name:
-                dist = d
-                break
-
+        dist = resolve_distribution(args.name)
         if dist is None:
             self.parser.error(f"Could not find distribution {args.name}")
             return

@@ -7,6 +7,7 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.distribution.registry import resolve_distribution


 class DistributionCreate(Subcommand):
@@ -23,7 +24,20 @@ class DistributionCreate(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_create_cmd)

     def _add_arguments(self):
-        pass
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the distribution to create",
+            required=True,
+        )
+        # for each ApiSurface the user wants to support, we should
+        # get the list of available adapters, ask which one the user
+        # wants to pick and then ask for their configuration.

     def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
+        dist = resolve_distribution(args.name)
+        if dist is not None:
+            self.parser.error(f"Distribution with name {args.name} already exists")
+            return
+
         raise NotImplementedError()

@@ -12,6 +12,7 @@ from .configure import DistributionConfigure
 from .create import DistributionCreate
 from .install import DistributionInstall
 from .list import DistributionList
+from .start import DistributionStart


 class DistributionParser(Subcommand):
@@ -31,3 +32,4 @@ class DistributionParser(Subcommand):
         DistributionInstall.create(subparsers)
         DistributionCreate.create(subparsers)
         DistributionConfigure.create(subparsers)
+        DistributionStart.create(subparsers)

@@ -11,9 +11,13 @@ import shlex
 import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.datatypes import distribution_dependencies
-from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.distribution.distribution import distribution_dependencies
+from llama_toolchain.distribution.registry import (
+    available_distributions,
+    resolve_distribution,
+)
 from llama_toolchain.utils import DISTRIBS_BASE_DIR

 from .utils import run_command, run_with_pty

 DISTRIBS = available_distributions()
@@ -55,12 +59,7 @@ class DistributionInstall(Subcommand):
             "distribution/install_distribution.sh",
         )

-        dist = None
-        for d in DISTRIBS:
-            if d.name == args.name:
-                dist = d
-                break
-
+        dist = resolve_distribution(args.name)
         if dist is None:
             self.parser.error(f"Could not find distribution {args.name}")
             return

@@ -9,7 +9,7 @@ import argparse
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table

-from llama_toolchain.distribution.datatypes import distribution_dependencies
+from llama_toolchain.distribution.distribution import distribution_dependencies
 from llama_toolchain.distribution.registry import available_distributions

llama_toolchain/cli/distribution/start.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import shlex
+from pathlib import Path
+
+import yaml
+
+from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.distribution.registry import resolve_distribution
+from llama_toolchain.distribution.server import main as distribution_server_init
+from llama_toolchain.utils import DISTRIBS_BASE_DIR
+
+from .utils import run_command
+
+
+class DistributionStart(Subcommand):
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "start",
+            prog="llama distribution start",
+            description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_distribution_start_cmd)
+
+    def _add_arguments(self):
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the distribution to start",
+            required=True,
+        )
+        self.parser.add_argument(
+            "--port",
+            type=int,
+            help="Port to run the server on. Defaults to 5000",
+            default=5000,
+        )
+        self.parser.add_argument(
+            "--disable-ipv6",
+            action="store_true",
+            help="Disable IPv6 support",
+            default=False,
+        )
+
+    def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
+        dist = resolve_distribution(args.name)
+        if dist is None:
+            self.parser.error(f"Distribution with name {args.name} not found")
+            return
+
+        config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
+        if not config_yaml.exists():
+            raise ValueError(
+                f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
+            )
+
+        with open(config_yaml, "r") as fp:
+            config = yaml.safe_load(fp)
+
+        conda_env = config["conda_env"]
+        python_exe = run_command(shlex.split("which python"))
+        # simple check, unfortunate
+        if conda_env not in python_exe:
+            raise ValueError(
+                f"Please re-run start after activating the `{conda_env}` conda environment first"
+            )
+
+        distribution_server_init(
+            dist.name,
+            config_yaml,
+            args.port,
+            disable_ipv6=args.disable_ipv6,
+        )
+        # run_with_pty(
+        #     shlex.split(
+        #         f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000"
+        #     )
+        # )

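The start flow above resolves the distribution, loads `$DISTRIBS_BASE_DIR/<name>/config.yaml`, checks that the matching conda environment is active, and then calls the new server entry point in-process. A minimal sketch of the equivalent direct invocation, assuming a distribution named "local-source" that has already been installed and configured (the name and config location here are illustrative, not taken from this diff):

    # Hedged sketch, not part of the commit: call the server entry point directly.
    from pathlib import Path

    from llama_toolchain.distribution.server import main as distribution_server_init
    from llama_toolchain.utils import DISTRIBS_BASE_DIR

    config_yaml = Path(DISTRIBS_BASE_DIR) / "local-source" / "config.yaml"  # assumed name
    distribution_server_init("local-source", config_yaml, 5000, disable_ipv6=False)
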
@@ -8,21 +8,35 @@ import errno
 import os
 import pty
 import select
+import signal
 import subprocess
 import sys
 import termios
-import tty
+
+from termcolor import cprint


 def run_with_pty(command):
-    old_settings = termios.tcgetattr(sys.stdin)
-
-    # Create a new pseudo-terminal
     master, slave = pty.openpty()

+    old_settings = termios.tcgetattr(sys.stdin)
+    original_sigint = signal.getsignal(signal.SIGINT)
+
+    ctrl_c_pressed = False
+
+    def sigint_handler(signum, frame):
+        nonlocal ctrl_c_pressed
+        ctrl_c_pressed = True
+        cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
+
     try:
-        # ensure the terminal does not echo input
-        tty.setraw(sys.stdin.fileno())
+        # Set up the signal handler
+        signal.signal(signal.SIGINT, sigint_handler)
+
+        new_settings = termios.tcgetattr(sys.stdin)
+        new_settings[3] = new_settings[3] & ~termios.ECHO  # Disable echo
+        new_settings[3] = new_settings[3] & ~termios.ICANON  # Disable canonical mode
+        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
+
         process = subprocess.Popen(
             command,
@@ -30,26 +44,36 @@ def run_with_pty(command):
             stdout=slave,
             stderr=slave,
             universal_newlines=True,
+            preexec_fn=os.setsid,
         )

         # Close the slave file descriptor as it's now owned by the subprocess
         os.close(slave)

         def handle_io():
-            while True:
-                rlist, _, _ = select.select([sys.stdin, master], [], [])
+            while not ctrl_c_pressed:
+                try:
+                    rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)

                     if sys.stdin in rlist:
                         data = os.read(sys.stdin.fileno(), 1024)
-                        if not data:  # EOF
+                        if not data:
                             break
                         os.write(master, data)

                     if master in rlist:
                         data = os.read(master, 1024)
                         if not data:
                             break
-                        os.write(sys.stdout.fileno(), data)
+                        sys.stdout.buffer.write(data)
+                        sys.stdout.flush()
+
+                except KeyboardInterrupt:
+                    # This will be raised when Ctrl+C is pressed
+                    break
+
+                if process.poll() is not None:
+                    break

         handle_io()
     except (EOFError, KeyboardInterrupt):
@@ -58,11 +82,14 @@ def run_with_pty(command):
             if e.errno != errno.EIO:
                 raise
     finally:
-        # Restore original terminal settings
+        # Clean up
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
+        signal.signal(signal.SIGINT, original_sigint)

-    process.wait()
     os.close(master)
+    if process.poll() is None:
+        process.terminate()
+        process.wait()
+
     return process.returncode

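For context on the termios handling above: index 3 of the list returned by `tcgetattr` is the local-mode flags word, so clearing `ECHO` and `ICANON` there turns off echoing and line buffering, which is what the two `new_settings[3]` lines do. A standalone sketch of the same idea (assumes stdin is a real terminal; not part of the commit):

    # Hedged sketch: the equivalent terminal tweak done inside run_with_pty above.
    import sys
    import termios

    attrs = termios.tcgetattr(sys.stdin)
    attrs[3] &= ~(termios.ECHO | termios.ICANON)  # disable echo and canonical mode
    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, attrs)
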
@@ -8,7 +8,8 @@ import argparse

 from .distribution import DistributionParser
 from .download import Download
-from .inference import InferenceParser
+
+# from .inference import InferenceParser
 from .model import ModelParser


@@ -29,7 +30,7 @@ class LlamaCLIParser:

         # Add sub-commands
         Download.create(subparsers)
-        InferenceParser.create(subparsers)
+        # InferenceParser.create(subparsers)
         ModelParser.create(subparsers)
         DistributionParser.create(subparsers)

@@ -5,54 +5,10 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Union
+from typing import Dict, List

 from pydantic import BaseModel, Field
 from strong_typing.schema import json_schema_type
-from typing_extensions import Annotated
-
-
-@json_schema_type
-class AdapterType(Enum):
-    passthrough_api = "passthrough_api"
-    python_impl = "python_impl"
-    not_implemented = "not_implemented"
-
-
-@json_schema_type
-class PassthroughApiAdapterConfig(BaseModel):
-    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
-    base_url: str = Field(..., description="The base URL for the llama stack provider")
-    headers: Dict[str, str] = Field(
-        default_factory=dict,
-        description="Headers (e.g., authorization) to send with the request",
-    )
-
-
-@json_schema_type
-class PythonImplAdapterConfig(BaseModel):
-    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
-    adapter_id: str
-    kwargs: Dict[str, Any] = Field(
-        default_factory=dict, description="kwargs to pass to the entrypoint"
-    )
-
-
-@json_schema_type
-class NotImplementedAdapterConfig(BaseModel):
-    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
-
-
-# should we define very granular typed classes for each of the PythonImplAdapters we will have?
-# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
-AdapterConfig = Annotated[
-    Union[
-        PassthroughApiAdapterConfig,
-        NotImplementedAdapterConfig,
-        PythonImplAdapterConfig,
-    ],
-    Field(discriminator="type"),
-]


 @json_schema_type
@@ -61,6 +17,13 @@ class ApiSurface(Enum):
     safety = "safety"


+@json_schema_type
+class ApiSurfaceEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+
+
 @json_schema_type
 class Adapter(BaseModel):
     api_surface: ApiSurface
@@ -108,13 +71,3 @@ class Distribution(BaseModel):
         default_factory=list,
         description="Additional pip packages beyond those required by the adapters",
     )
-
-
-def distribution_dependencies(distribution: Distribution) -> List[str]:
-    # only consider SourceAdapters when calculating dependencies
-    return [
-        dep
-        for adapter in distribution.adapters.values()
-        if isinstance(adapter, SourceAdapter)
-        for dep in adapter.pip_packages
-    ] + distribution.additional_pip_packages

llama_toolchain/distribution/distribution.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import inspect
+from typing import Dict, List
+
+from llama_toolchain.inference.api.endpoints import Inference
+from llama_toolchain.safety.api.endpoints import Safety
+
+from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter
+
+
+def distribution_dependencies(distribution: Distribution) -> List[str]:
+    # only consider SourceAdapters when calculating dependencies
+    return [
+        dep
+        for adapter in distribution.adapters.values()
+        if isinstance(adapter, SourceAdapter)
+        for dep in adapter.pip_packages
+    ] + distribution.additional_pip_packages
+
+
+def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
+    surfaces = {}
+
+    protocols = {
+        ApiSurface.inference: Inference,
+        ApiSurface.safety: Safety,
+    }
+
+    for surface, protocol in protocols.items():
+        endpoints = []
+        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+
+        for name, method in protocol_methods:
+            if not hasattr(method, "__webmethod__"):
+                continue
+
+            webmethod = method.__webmethod__
+            route = webmethod.route
+
+            # use `post` for all methods right now until we fix up the `webmethod` openapi
+            # annotation and write our own openapi generator
+            endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name))
+
+        surfaces[surface] = endpoints
+
+    return surfaces

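A rough illustration of how the `ApiSurface -> [ApiSurfaceEndpoint]` map built by `api_surface_endpoints()` can be inspected; the printed routes come from each protocol method's `@webmethod` annotation, and every endpoint is registered as POST for now, as the comment in the file notes (sketch only, not part of the commit):

    # Hedged sketch: list the endpoints the server would register.
    from llama_toolchain.distribution.distribution import api_surface_endpoints

    for surface, endpoints in api_surface_endpoints().items():
        for ep in endpoints:
            print(f"{surface.value}: {ep.method.upper()} {ep.route} -> {ep.name}")
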
llama_toolchain/distribution/dynamic.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import importlib
+from typing import Any, Dict
+
+from .datatypes import SourceAdapter
+
+
+def instantiate_class_type(fully_qualified_name):
+    module_name, class_name = fully_qualified_name.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, class_name)
+
+
+# returns a class implementing the protocol corresponding to the ApiSurface
+def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
+    module = importlib.import_module(adapter.module)
+
+    config_type = instantiate_class_type(adapter.config_class)
+    config = config_type(**adapter_config)
+    return asyncio.run(module.get_adapter_impl(config))

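`instantiate_class_type` simply splits a dotted path and imports the final attribute; `instantiate_adapter` uses it to build the adapter's config class before constructing the implementation. A tiny sketch of the helper in isolation (the dotted path below is a stand-in, not a real adapter config class):

    # Hedged sketch: resolve a class from a fully qualified name.
    from llama_toolchain.distribution.dynamic import instantiate_class_type

    cls = instantiate_class_type("collections.OrderedDict")  # stand-in path
    assert cls.__name__ == "OrderedDict"
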
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List
+from typing import List, Optional

 from llama_toolchain.inference.adapters import available_inference_adapters

@@ -63,3 +63,10 @@ def available_distributions() -> List[Distribution]:
             },
         ),
     ]
+
+
+def resolve_distribution(name: str) -> Optional[Distribution]:
+    for dist in available_distributions():
+        if dist.name == name:
+            return dist
+    return None

llama_toolchain/distribution/server.py (new file, 202 lines)
@@ -0,0 +1,202 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+import signal
+from collections.abc import (
+    AsyncGenerator as AsyncGeneratorABC,
+    AsyncIterator as AsyncIteratorABC,
+)
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+
+import fire
+import httpx
+import yaml
+from dotenv import load_dotenv
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from fastapi.routing import APIRoute
+
+from pydantic import BaseModel
+from termcolor import cprint
+
+from .datatypes import PassthroughApiAdapter
+from .distribution import api_surface_endpoints
+from .dynamic import instantiate_adapter
+
+from .registry import resolve_distribution
+
+load_dotenv()
+
+
+def is_async_iterator_type(typ):
+    if hasattr(typ, "__origin__"):
+        origin = typ.__origin__
+        if isinstance(origin, type):
+            return issubclass(
+                origin,
+                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
+            )
+        return False
+    return isinstance(
+        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
+    )
+
+
+def create_sse_event(data: Any) -> str:
+    if isinstance(data, BaseModel):
+        data = data.json()
+    else:
+        data = json.dumps(data)
+
+    return f"data: {data}\n\n"
+
+
+async def passthrough(
+    request: Request,
+    downstream_url: str,
+    downstream_headers: Optional[Dict[str, str]] = None,
+):
+    client = httpx.AsyncClient()
+
+    headers = dict(request.headers)
+    headers.pop("host", None)
+    headers.update(downstream_headers or {})
+
+    body = await request.body()
+
+    try:
+        response = await client.request(
+            method=request.method,
+            url=downstream_url,
+            headers=headers,
+            content=body,
+            params=request.query_params,
+        )
+        return StreamingResponse(
+            response.iter_bytes(),
+            status_code=response.status_code,
+            headers=dict(response.headers),
+        )
+    finally:
+        await client.aclose()
+
+
+def handle_sigint(*args, **kwargs):
+    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
+    loop = asyncio.get_event_loop()
+    for task in asyncio.all_tasks(loop):
+        task.cancel()
+    loop.stop()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    print("Starting up")
+    yield
+    print("Shutting down")
+
+
+def create_dynamic_passthrough(
+    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
+):
+    async def endpoint(request: Request):
+        return await passthrough(request, downstream_url, downstream_headers)
+
+    return endpoint
+
+
+def create_dynamic_typed_route(func: Any):
+    hints = get_type_hints(func)
+    request_model = next(iter(hints.values()))
+    response_model = hints["return"]
+
+    is_streaming = is_async_iterator_type(response_model)
+
+    if is_streaming:
+
+        async def endpoint(request: request_model):
+            async def event_generator():
+                async for item in func(request):
+                    yield create_sse_event(item)
+                    await asyncio.sleep(0.001)
+
+            return StreamingResponse(event_generator(), media_type="text/event-stream")
+
+    else:
+
+        async def endpoint(request: request_model):
+            return func(request)
+
+    return endpoint
+
+
+def main(
+    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
+):
+    dist = resolve_distribution(dist_name)
+    if dist is None:
+        raise ValueError(f"Could not find distribution {dist_name}")
+
+    with open(yaml_config, "r") as fp:
+        config = yaml.safe_load(fp)
+
+    app = FastAPI()
+
+    all_endpoints = api_surface_endpoints()
+
+    adapter_configs = config["adapters"]
+    for surface, adapter in dist.adapters.items():
+        if surface.value not in adapter_configs:
+            raise ValueError(
+                f"Could not find adapter config for {surface}. Please add it to the config"
+            )
+
+        adapter_config = adapter_configs[surface.value]
+        endpoints = all_endpoints[surface]
+        if isinstance(adapter, PassthroughApiAdapter):
+            for endpoint in endpoints:
+                url = adapter.base_url.rstrip("/") + endpoint.route
+                getattr(app, endpoint.method)(endpoint.route)(
+                    create_dynamic_passthrough(url)
+                )
+        else:
+            impl = instantiate_adapter(adapter, adapter_config)
+            for endpoint in endpoints:
+                if not hasattr(impl, endpoint.name):
+                    # ideally this should be a typing violation already
+                    raise ValueError(
+                        f"Could not find method {endpoint.name} on {impl}!!"
+                    )
+
+                impl_method = getattr(impl, endpoint.name)
+                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
+                    create_dynamic_typed_route(impl_method)
+                )
+
+    for route in app.routes:
+        if isinstance(route, APIRoute):
+            cprint(
+                f"Serving {next(iter(route.methods))} {route.path}",
+                "white",
+                attrs=["bold"],
+            )
+
+    signal.signal(signal.SIGINT, handle_sigint)
+
+    import uvicorn
+
+    # FYI this does not do hot-reloads
+    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
+    print(f"Listening on {listen_host}:{port}")
+    uvicorn.run(app, host=listen_host, port=port)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

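`main()` above only requires two top-level keys from the YAML it is handed: `conda_env` (checked by the CLI before launch) and `adapters`, keyed by each `ApiSurface` value. A hedged sketch of that shape, written as the dict `yaml.safe_load` would return (the adapter bodies are placeholders; their real fields come from each adapter's `config_class`):

    # Illustrative only, not a verified sample config from this commit.
    example_config = {
        "conda_env": "local-source",  # checked by `llama distribution start` before launch
        "adapters": {
            "inference": {},  # fields defined by the inference adapter's config_class
            "safety": {},     # fields defined by the safety adapter's config_class
        },
    }
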
@@ -19,7 +19,7 @@ def available_inference_adapters() -> List[Adapter]:
                 "zmq",
             ],
             module="llama_toolchain.inference.inference",
-            config_class="llama_toolchain.inference.inference.InlineImplConfig",
+            config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
         ),
         SourceAdapter(
             api_surface=ApiSurface.inference,

@@ -7,9 +7,6 @@
 from enum import Enum
 from typing import Literal, Optional, Union

-from hydra.core.config_store import ConfigStore
-
-from hydra_zen import builds
 from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat

 from pydantic import BaseModel, Field
@@ -19,13 +16,6 @@ from typing_extensions import Annotated
 from .datatypes import QuantizationConfig


-@json_schema_type
-class ImplType(Enum):
-    inline = "inline"
-    remote = "remote"
-    ollama = "ollama"
-
-
 @json_schema_type
 class CheckpointType(Enum):
     pytorch = "pytorch"
@@ -66,8 +56,8 @@ class ModelCheckpointConfig(BaseModel):


 @json_schema_type
-class InlineImplConfig(BaseModel):
-    impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
+class MetaReferenceImplConfig(BaseModel):
+    model: str
     checkpoint_config: ModelCheckpointConfig
     quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
@@ -75,28 +65,7 @@ class InlineImplConfig(BaseModel):
     max_batch_size: int = 1


-@json_schema_type
-class RemoteImplConfig(BaseModel):
-    impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
-    url: str = Field(..., description="The URL of the remote module")
-
-
 @json_schema_type
 class OllamaImplConfig(BaseModel):
-    impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
     model: str = Field(..., description="The name of the model in ollama catalog")
     url: str = Field(..., description="The URL for the ollama server")
-
-
-@json_schema_type
-class InferenceConfig(BaseModel):
-    impl_config: Annotated[
-        Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
-        Field(discriminator="impl_type"),
-    ]
-
-
-InferenceHydraConfig = builds(InferenceConfig)
-
-cs = ConfigStore.instance()
-cs.store(name="inference_config", node=InferenceHydraConfig)

@@ -4,19 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .api.config import ImplType, InferenceConfig
+# from .api.config import ImplType, InferenceConfig


-async def get_inference_api_instance(config: InferenceConfig):
-    if config.impl_config.impl_type == ImplType.inline.value:
-        from .inference import InferenceImpl
+# async def get_inference_api_instance(config: InferenceConfig):
+#     if config.impl_config.impl_type == ImplType.inline.value:
+#         from .inference import InferenceImpl

-        return InferenceImpl(config.impl_config)
-    elif config.impl_config.impl_type == ImplType.ollama.value:
-        from .ollama import OllamaInference
+#         return InferenceImpl(config.impl_config)
+#     elif config.impl_config.impl_type == ImplType.ollama.value:
+#         from .ollama import OllamaInference

-        return OllamaInference(config.impl_config)
+#         return OllamaInference(config.impl_config)

-    from .client import InferenceClient
+#     from .client import InferenceClient

-    return InferenceClient(config.impl_config.url)
+#     return InferenceClient(config.impl_config.url)

@@ -29,7 +29,7 @@ from llama_models.llama3_1.api.model import Transformer
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint

-from .api.config import CheckpointType, InlineImplConfig
+from .api.config import CheckpointType, MetaReferenceImplConfig
 from .api.datatypes import QuantizationType


@@ -42,7 +42,7 @@ class TokenResult:

 class Llama:
     @staticmethod
-    def build(config: InlineImplConfig):
+    def build(config: MetaReferenceImplConfig):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

@@ -4,11 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+import asyncio
+
+from typing import AsyncIterator, Union

 from llama_models.llama3_1.api.datatypes import StopReason
+from llama_models.sku_list import resolve_model

-from .api.config import InlineImplConfig
+from .api.config import MetaReferenceImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -19,23 +22,35 @@ from .api.endpoints import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
-    CompletionRequest,
     Inference,
 )
 from .model_parallel import LlamaModelParallelGenerator


-def get_adapter_impl(config: InlineImplConfig) -> Inference:
+async def get_adapter_impl(config: MetaReferenceImplConfig):
     assert isinstance(
-        config, InlineImplConfig
+        config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"
-    return InferenceImpl(config)
+
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl
+
+
+# there's a single model parallel process running serving the model. for now,
+# we don't support multiple concurrent requests to this process.
+SEMAPHORE = asyncio.Semaphore(1)


-class InferenceImpl(Inference):
+class MetaReferenceInferenceImpl(Inference):

-    def __init__(self, config: InlineImplConfig) -> None:
+    def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
+        model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model = model
+        # verify that the checkpoint actually is for this model lol

     async def initialize(self) -> None:
         self.generator = LlamaModelParallelGenerator(self.config)
@@ -44,125 +59,144 @@ class InferenceImpl(Inference):
     async def shutdown(self) -> None:
         self.generator.stop()

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
-        raise NotImplementedError()
-
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    # hm, when stream=False, we should not be doing SSE :/ which is what the
+    # top-level server is going to do. make the typing more specific here
+    async def chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncIterator[
+        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+    ]:
+        model = resolve_model(request.model)
+        if model is None:
+            raise RuntimeError(
+                f"Unknown model: {request.model}, Run `llama model list`"
+            )
+        elif model.descriptor() != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+            )
+
+        if SEMAPHORE.locked():
+            raise RuntimeError("Only one concurrent request is supported")
+
+        async with SEMAPHORE:
+            if request.stream:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.start,
+                        delta="",
+                    )
+                )
+
+            tokens = []
+            logprobs = []
+
+            stop_reason = None
+
+            buffer = ""
+            ipython = False
+
+            for token_result in self.generator.chat_completion(
+                messages=request.messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+            ):
+                buffer += token_result.text
+                tokens.append(token_result.token)
+
+                if not ipython and buffer.startswith("<|python_tag|>"):
+                    ipython = True
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
+                    buffer = buffer[len("<|python_tag|>") :]
+                    continue
+
+                if not request.stream:
+                    if request.logprobs:
+                        logprobs.append(token_result.logprob)
+
+                    continue
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                if ipython:
+                    delta = ToolCallDelta(
+                        content=text,
+                        parse_status=ToolCallParseStatus.in_progress,
+                    )
+                else:
+                    delta = text
+
+                if stop_reason is None:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=delta,
+                            stop_reason=stop_reason,
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            # TODO(ashwin): parse tool calls separately here and report errors?
+            # if someone breaks the iteration before coming here we are toast
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            if request.stream:
+                parsed_tool_calls = len(message.tool_calls) > 0
+                if ipython and not parsed_tool_calls:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.failure,
+                            ),
+                            stop_reason=stop_reason,
+                        )
+                    )
+
+                for tool_call in message.tool_calls:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content=tool_call,
+                                parse_status=ToolCallParseStatus.success,
+                            ),
+                            stop_reason=stop_reason,
+                        )
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.complete,
+                        delta="",
+                        stop_reason=stop_reason,
+                    )
+                )
+
+                # TODO(ashwin): what else do we need to send out here when everything finishes?
+            else:
+                yield ChatCompletionResponse(
+                    completion_message=message,
+                    logprobs=logprobs if request.logprobs else None,
+                )

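The rewritten `chat_completion` is an async generator guarded by a module-level semaphore, so only one request is served at a time. A hedged sketch of how a caller with `stream=True` might drive it (the request object and an initialized `impl` are assumed to exist; not part of the commit):

    # Hedged sketch: consume the streaming chunks yielded above.
    async def consume(impl, request):
        async for chunk in impl.chat_completion(request):
            # each chunk wraps a ChatCompletionResponseEvent
            print(chunk.event.event_type, chunk.event.delta)
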
@@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer

-from .api.config import InlineImplConfig
+from .api.config import MetaReferenceImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup

@@ -42,7 +42,7 @@ class ModelRunner:
         )


-def init_model_cb(config: InlineImplConfig):
+def init_model_cb(config: MetaReferenceImplConfig):
     llama = Llama.build(config)
     return ModelRunner(llama)

@@ -58,7 +58,7 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: InlineImplConfig):
+    def __init__(self, config: MetaReferenceImplConfig):
         self.config = config

         # this is a hack because Agent's loop uses this to tokenize and check if input is too long

@@ -17,7 +17,7 @@ from llama_models.llama3_1.api.model import Transformer, TransformerBlock

 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
-    InlineImplConfig,
+    MetaReferenceImplConfig,
 )
 from llama_toolchain.inference.api.datatypes import QuantizationType

@@ -46,7 +46,7 @@ def swiglu_wrapper(

 def convert_to_quantized_model(
     model: Transformer,
-    config: InlineImplConfig,
+    config: MetaReferenceImplConfig,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from .datatypes import *  # noqa: F403
-from typing import Protocol
+from typing import List, Protocol

 from llama_models.llama3_1.api.datatypes import Message

@@ -19,7 +19,7 @@ class RunShieldRequest(BaseModel):
     messages: List[Message]


-class SafetyCheck(Protocol):
+class Safety(Protocol):

     @webmethod(route="/safety/run_shield")
     async def run_shield(