Distribution server now functioning

Ashwin Bharambe 2024-08-02 13:37:40 -07:00
parent 041cafbee3
commit 2cf9915806
21 changed files with 635 additions and 266 deletions

View file

@@ -18,11 +18,13 @@ from termcolor import cprint
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
-from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.distribution.registry import (
+    available_distributions,
+    resolve_distribution,
+)
 from llama_toolchain.utils import DISTRIBS_BASE_DIR
-from .utils import run_command

-DISTRIBS = available_distributions()
+from .utils import run_command


 class DistributionConfigure(Subcommand):
@@ -43,18 +45,13 @@ class DistributionConfigure(Subcommand):
         self.parser.add_argument(
             "--name",
             type=str,
-            help="Mame of the distribution to configure",
+            help="Name of the distribution to configure",
             default="local-source",
             choices=[d.name for d in available_distributions()],
         )

     def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
-        dist = None
-        for d in DISTRIBS:
-            if d.name == args.name:
-                dist = d
-                break
-
+        dist = resolve_distribution(args.name)
         if dist is None:
             self.parser.error(f"Could not find distribution {args.name}")
             return

View file

@@ -7,6 +7,7 @@
 import argparse

 from llama_toolchain.cli.subcommand import Subcommand
+from llama_toolchain.distribution.registry import resolve_distribution


 class DistributionCreate(Subcommand):
@@ -23,7 +24,20 @@ class DistributionCreate(Subcommand):
         self.parser.set_defaults(func=self._run_distribution_create_cmd)

     def _add_arguments(self):
-        pass
+        self.parser.add_argument(
+            "--name",
+            type=str,
+            help="Name of the distribution to create",
+            required=True,
+        )
+        # for each ApiSurface the user wants to support, we should
+        # get the list of available adapters, ask which one the user
+        # wants to pick and then ask for their configuration.

     def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
+        dist = resolve_distribution(args.name)
+        if dist is not None:
+            self.parser.error(f"Distribution with name {args.name} already exists")
+            return
+
         raise NotImplementedError()
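The comment added in _add_arguments describes a flow that is not implemented yet. A rough sketch of what one interactive step of it could look like, under the assumption that adapters expose an adapter_id field; the helper name and prompt wording are hypothetical and not part of this commit:

def pick_adapter(surface_name, adapters):
    # Print the candidate adapters for one ApiSurface and let the user choose by index.
    print(f"Available adapters for {surface_name}:")
    for i, adapter in enumerate(adapters):
        print(f"  [{i}] {adapter.adapter_id}")  # adapter_id is an assumed field name
    choice = int(input("Pick one: "))
    return adapters[choice]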

View file

@@ -12,6 +12,7 @@ from .configure import DistributionConfigure
 from .create import DistributionCreate
 from .install import DistributionInstall
 from .list import DistributionList
+from .start import DistributionStart


 class DistributionParser(Subcommand):
@@ -31,3 +32,4 @@ class DistributionParser(Subcommand):
         DistributionInstall.create(subparsers)
         DistributionCreate.create(subparsers)
         DistributionConfigure.create(subparsers)
+        DistributionStart.create(subparsers)

View file

@@ -11,9 +11,13 @@ import shlex
 import pkg_resources

 from llama_toolchain.cli.subcommand import Subcommand
-from llama_toolchain.distribution.datatypes import distribution_dependencies
-from llama_toolchain.distribution.registry import available_distributions
+from llama_toolchain.distribution.distribution import distribution_dependencies
+from llama_toolchain.distribution.registry import (
+    available_distributions,
+    resolve_distribution,
+)
 from llama_toolchain.utils import DISTRIBS_BASE_DIR

 from .utils import run_command, run_with_pty

 DISTRIBS = available_distributions()
@@ -55,12 +59,7 @@ class DistributionInstall(Subcommand):
             "distribution/install_distribution.sh",
         )

-        dist = None
-        for d in DISTRIBS:
-            if d.name == args.name:
-                dist = d
-                break
-
+        dist = resolve_distribution(args.name)
         if dist is None:
             self.parser.error(f"Could not find distribution {args.name}")
             return

View file

@@ -9,7 +9,7 @@ import argparse

 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.cli.table import print_table
-from llama_toolchain.distribution.datatypes import distribution_dependencies
+from llama_toolchain.distribution.distribution import distribution_dependencies
 from llama_toolchain.distribution.registry import available_distributions

View file

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import shlex
from pathlib import Path

import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.server import main as distribution_server_init
from llama_toolchain.utils import DISTRIBS_BASE_DIR

from .utils import run_command


class DistributionStart(Subcommand):
    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "start",
            prog="llama distribution start",
            description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_start_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to start",
            required=True,
        )
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 5000",
            default=5000,
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            help="Disable IPv6 support",
            default=False,
        )

    def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
        dist = resolve_distribution(args.name)
        if dist is None:
            self.parser.error(f"Distribution with name {args.name} not found")
            return

        config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
        if not config_yaml.exists():
            raise ValueError(
                f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
            )

        with open(config_yaml, "r") as fp:
            config = yaml.safe_load(fp)

        conda_env = config["conda_env"]
        python_exe = run_command(shlex.split("which python"))
        # simple check, unfortunate
        if conda_env not in python_exe:
            raise ValueError(
                f"Please re-run start after activating the `{conda_env}` conda environment first"
            )

        distribution_server_init(
            dist.name,
            config_yaml,
            args.port,
            disable_ipv6=args.disable_ipv6,
        )

        # run_with_pty(
        #     shlex.split(
        #         f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000"
        #     )
        # )
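For orientation, `llama distribution start` resolves the distribution by name, reads the per-distribution config.yaml from under DISTRIBS_BASE_DIR, checks that the matching conda environment is active, and then calls the server entry point in-process. A minimal sketch of reading that config, assuming only the keys used elsewhere in this commit (`conda_env`, `adapters`); the base directory shown here is an assumption, since DISTRIBS_BASE_DIR is defined outside this diff:

from pathlib import Path

import yaml

# Assumed location; the real value comes from llama_toolchain.utils.DISTRIBS_BASE_DIR.
distribs_base_dir = Path.home() / ".llama" / "distributions"

config_yaml = distribs_base_dir / "local-source" / "config.yaml"
if config_yaml.exists():
    config = yaml.safe_load(config_yaml.read_text())
    print(config["conda_env"])      # conda environment the distribution was installed into
    print(list(config["adapters"]))  # one adapter config per ApiSurface value, e.g. "inference"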

View file

@@ -8,21 +8,35 @@ import errno
 import os
 import pty
 import select
+import signal
 import subprocess
 import sys
 import termios
-import tty
+
+from termcolor import cprint


 def run_with_pty(command):
+    old_settings = termios.tcgetattr(sys.stdin)
+
+    # Create a new pseudo-terminal
     master, slave = pty.openpty()
-    old_settings = termios.tcgetattr(sys.stdin)
+
+    original_sigint = signal.getsignal(signal.SIGINT)
+    ctrl_c_pressed = False
+
+    def sigint_handler(signum, frame):
+        nonlocal ctrl_c_pressed
+        ctrl_c_pressed = True
+        cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])
+
     try:
-        # ensure the terminal does not echo input
-        tty.setraw(sys.stdin.fileno())
+        # Set up the signal handler
+        signal.signal(signal.SIGINT, sigint_handler)
+
+        new_settings = termios.tcgetattr(sys.stdin)
+        new_settings[3] = new_settings[3] & ~termios.ECHO  # Disable echo
+        new_settings[3] = new_settings[3] & ~termios.ICANON  # Disable canonical mode
+        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)

         process = subprocess.Popen(
             command,
@@ -30,26 +44,36 @@ def run_with_pty(command):
             stdout=slave,
             stderr=slave,
             universal_newlines=True,
+            preexec_fn=os.setsid,
         )

         # Close the slave file descriptor as it's now owned by the subprocess
         os.close(slave)

         def handle_io():
-            while True:
-                rlist, _, _ = select.select([sys.stdin, master], [], [])
-
-                if sys.stdin in rlist:
-                    data = os.read(sys.stdin.fileno(), 1024)
-                    if not data:  # EOF
-                        break
-                    os.write(master, data)
-
-                if master in rlist:
-                    data = os.read(master, 1024)
-                    if not data:
-                        break
-                    os.write(sys.stdout.fileno(), data)
+            while not ctrl_c_pressed:
+                try:
+                    rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
+                    if sys.stdin in rlist:
+                        data = os.read(sys.stdin.fileno(), 1024)
+                        if not data:
+                            break
+                        os.write(master, data)
+
+                    if master in rlist:
+                        data = os.read(master, 1024)
+                        if not data:
+                            break
+                        sys.stdout.buffer.write(data)
+                        sys.stdout.flush()
+                except KeyboardInterrupt:
+                    # This will be raised when Ctrl+C is pressed
+                    break
+
+                if process.poll() is not None:
+                    break

         handle_io()
     except (EOFError, KeyboardInterrupt):
@@ -58,11 +82,14 @@ def run_with_pty(command):
         if e.errno != errno.EIO:
             raise
     finally:
-        # Restore original terminal settings
+        # Clean up
         termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
+        signal.signal(signal.SIGINT, original_sigint)
+        os.close(master)

-        process.wait()
-        os.close(master)
+        if process.poll() is None:
+            process.terminate()
+            process.wait()

     return process.returncode
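The rewritten helper installs a temporary SIGINT handler, disables echo and canonical mode instead of putting the terminal into raw mode, polls select() with a timeout so it can notice Ctrl-C and child exit, and terminates the child during cleanup. A minimal usage sketch; the command is illustrative and the import path of this utils module is not shown in the diff:

# from llama_toolchain.cli.distribution.utils import run_with_pty  # assumed module path

def demo():
    # Runs a short-lived child under a pseudo-terminal and returns its exit code.
    rc = run_with_pty(["bash", "-c", "echo hello from a pty"])
    print(f"child exited with {rc}")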

View file

@@ -8,7 +8,8 @@ import argparse

 from .distribution import DistributionParser
 from .download import Download
-from .inference import InferenceParser
+
+# from .inference import InferenceParser
 from .model import ModelParser
@@ -29,7 +30,7 @@ class LlamaCLIParser:
         # Add sub-commands
         Download.create(subparsers)
-        InferenceParser.create(subparsers)
+        # InferenceParser.create(subparsers)
         ModelParser.create(subparsers)
         DistributionParser.create(subparsers)

View file

@@ -5,54 +5,10 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Union
+from typing import Dict, List

 from pydantic import BaseModel, Field
 from strong_typing.schema import json_schema_type
-from typing_extensions import Annotated
-
-
-@json_schema_type
-class AdapterType(Enum):
-    passthrough_api = "passthrough_api"
-    python_impl = "python_impl"
-    not_implemented = "not_implemented"
-
-
-@json_schema_type
-class PassthroughApiAdapterConfig(BaseModel):
-    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
-    base_url: str = Field(..., description="The base URL for the llama stack provider")
-    headers: Dict[str, str] = Field(
-        default_factory=dict,
-        description="Headers (e.g., authorization) to send with the request",
-    )
-
-
-@json_schema_type
-class PythonImplAdapterConfig(BaseModel):
-    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
-    adapter_id: str
-    kwargs: Dict[str, Any] = Field(
-        default_factory=dict, description="kwargs to pass to the entrypoint"
-    )
-
-
-@json_schema_type
-class NotImplementedAdapterConfig(BaseModel):
-    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
-
-
-# should we define very granular typed classes for each of the PythonImplAdapters we will have?
-# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
-AdapterConfig = Annotated[
-    Union[
-        PassthroughApiAdapterConfig,
-        NotImplementedAdapterConfig,
-        PythonImplAdapterConfig,
-    ],
-    Field(discriminator="type"),
-]


 @json_schema_type
@@ -61,6 +17,13 @@ class ApiSurface(Enum):
     safety = "safety"


+@json_schema_type
+class ApiSurfaceEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+
+
 @json_schema_type
 class Adapter(BaseModel):
     api_surface: ApiSurface
@@ -108,13 +71,3 @@ class Distribution(BaseModel):
         default_factory=list,
         description="Additional pip packages beyond those required by the adapters",
     )
-
-
-def distribution_dependencies(distribution: Distribution) -> List[str]:
-    # only consider SourceAdapters when calculating dependencies
-    return [
-        dep
-        for adapter in distribution.adapters.values()
-        if isinstance(adapter, SourceAdapter)
-        for dep in adapter.pip_packages
-    ] + distribution.additional_pip_packages
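The discriminated adapter-config union is removed from datatypes, distribution_dependencies moves into the new distribution module, and a small ApiSurfaceEndpoint record is added for the server to route on. It is just a typed (route, method, name) triple; the route value below is illustrative, since real routes come from the @webmethod annotations:

from llama_toolchain.distribution.datatypes import ApiSurfaceEndpoint

endpoint = ApiSurfaceEndpoint(
    route="/inference/chat_completion",  # illustrative route
    method="post",
    name="chat_completion",
)
print(endpoint.route, endpoint.method, endpoint.name)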

View file

@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from typing import Dict, List

from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety

from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter


def distribution_dependencies(distribution: Distribution) -> List[str]:
    # only consider SourceAdapters when calculating dependencies
    return [
        dep
        for adapter in distribution.adapters.values()
        if isinstance(adapter, SourceAdapter)
        for dep in adapter.pip_packages
    ] + distribution.additional_pip_packages


def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
    surfaces = {}

    protocols = {
        ApiSurface.inference: Inference,
        ApiSurface.safety: Safety,
    }

    for surface, protocol in protocols.items():
        endpoints = []
        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

        for name, method in protocol_methods:
            if not hasattr(method, "__webmethod__"):
                continue

            webmethod = method.__webmethod__
            route = webmethod.route

            # use `post` for all methods right now until we fix up the `webmethod` openapi
            # annotation and write our own openapi generator
            endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name))

        surfaces[surface] = endpoints

    return surfaces
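api_surface_endpoints() relies on the @webmethod decorator attaching a __webmethod__ attribute (carrying at least a route) to each protocol method. A self-contained sketch of that pattern with a stand-in decorator, since the real decorator lives elsewhere in the codebase and may carry more fields:

import inspect
from typing import Protocol


class WebMethod:
    # Minimal stand-in for the metadata object stored on decorated methods.
    def __init__(self, route: str):
        self.route = route


def webmethod(route: str):
    def wrap(fn):
        fn.__webmethod__ = WebMethod(route=route)
        return fn
    return wrap


class Toy(Protocol):
    @webmethod(route="/toy/echo")
    async def echo(self, request: dict) -> dict: ...


for name, fn in inspect.getmembers(Toy, predicate=inspect.isfunction):
    if hasattr(fn, "__webmethod__"):
        print(name, fn.__webmethod__.route)  # -> echo /toy/echo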

View file

@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib
from typing import Any, Dict

from .datatypes import SourceAdapter


def instantiate_class_type(fully_qualified_name):
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# returns a class implementing the protocol corresponding to the ApiSurface
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
    module = importlib.import_module(adapter.module)

    config_type = instantiate_class_type(adapter.config_class)
    config = config_type(**adapter_config)
    return asyncio.run(module.get_adapter_impl(config))
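instantiate_adapter assumes each source adapter module exposes an async get_adapter_impl(config) factory and that adapter.config_class names a pydantic model it can hydrate from the YAML dict. A toy version of that contract; the names here are illustrative, not part of the commit:

import asyncio

from pydantic import BaseModel


class EchoImplConfig(BaseModel):
    greeting: str = "hello"


class EchoImpl:
    def __init__(self, config: EchoImplConfig):
        self.config = config


async def get_adapter_impl(config: EchoImplConfig):
    # The real factories also run any async initialization here.
    return EchoImpl(config)


# instantiate_adapter effectively does the following with the resolved module:
impl = asyncio.run(get_adapter_impl(EchoImplConfig(**{"greeting": "hi"})))
print(impl.config.greeting)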

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List
+from typing import List, Optional

 from llama_toolchain.inference.adapters import available_inference_adapters
@@ -63,3 +63,10 @@ def available_distributions() -> List[Distribution]:
             },
         ),
     ]
+
+
+def resolve_distribution(name: str) -> Optional[Distribution]:
+    for dist in available_distributions():
+        if dist.name == name:
+            return dist
+    return None
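resolve_distribution is a plain linear lookup over the static registry, which is what lets the CLI subcommands above drop their hand-rolled for-loops:

from llama_toolchain.distribution.registry import resolve_distribution

dist = resolve_distribution("local-source")
if dist is None:
    raise SystemExit("unknown distribution")
print(dist.name)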

View file

@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import signal
from collections.abc import (
    AsyncGenerator as AsyncGeneratorABC,
    AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional

import fire
import httpx
import yaml

from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel
from termcolor import cprint

from .datatypes import PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter
from .registry import resolve_distribution

load_dotenv()


def is_async_iterator_type(typ):
    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        if isinstance(origin, type):
            return issubclass(
                origin,
                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
            )
        return False
    return isinstance(
        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
    )


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.json()
    else:
        data = json.dumps(data)

    return f"data: {data}\n\n"


async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    client = httpx.AsyncClient()

    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    body = await request.body()

    try:
        response = await client.request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=body,
            params=request.query_params,
        )
        return StreamingResponse(
            response.iter_bytes(),
            status_code=response.status_code,
            headers=dict(response.headers),
        )
    finally:
        await client.aclose()


def handle_sigint(*args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    loop.stop()


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield
    print("Shutting down")


def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint


def create_dynamic_typed_route(func: Any):
    hints = get_type_hints(func)
    request_model = next(iter(hints.values()))
    response_model = hints["return"]

    is_streaming = is_async_iterator_type(response_model)

    if is_streaming:

        async def endpoint(request: request_model):
            async def event_generator():
                async for item in func(request):
                    yield create_sse_event(item)
                    await asyncio.sleep(0.001)

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    else:

        async def endpoint(request: request_model):
            return func(request)

    return endpoint


def main(
    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
    dist = resolve_distribution(dist_name)
    if dist is None:
        raise ValueError(f"Could not find distribution {dist_name}")

    with open(yaml_config, "r") as fp:
        config = yaml.safe_load(fp)

    app = FastAPI()

    all_endpoints = api_surface_endpoints()

    adapter_configs = config["adapters"]
    for surface, adapter in dist.adapters.items():
        if surface.value not in adapter_configs:
            raise ValueError(
                f"Could not find adapter config for {surface}. Please add it to the config"
            )

        adapter_config = adapter_configs[surface.value]
        endpoints = all_endpoints[surface]
        if isinstance(adapter, PassthroughApiAdapter):
            for endpoint in endpoints:
                url = adapter.base_url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            impl = instantiate_adapter(adapter, adapter_config)
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(impl_method)
                )

    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )

    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)


if __name__ == "__main__":
    fire.Fire(main)
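On the wire, each streamed chunk is framed by create_sse_event as a "data: <json>" line followed by a blank line. A hedged client-side sketch using httpx, which is already a dependency of this server module; the route, model name, and request body shape are illustrative and not taken from this diff:

import httpx


def stream_chat_completion():
    with httpx.stream(
        "POST",
        "http://localhost:5000/inference/chat_completion",  # illustrative route
        json={"model": "Meta-Llama3.1-8B-Instruct", "messages": [], "stream": True},
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            if line.startswith("data: "):
                print(line[len("data: "):])  # one JSON-encoded chunk per SSE event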

View file

@@ -19,7 +19,7 @@ def available_inference_adapters() -> List[Adapter]:
             "zmq",
         ],
         module="llama_toolchain.inference.inference",
-        config_class="llama_toolchain.inference.inference.InlineImplConfig",
+        config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
     ),
     SourceAdapter(
         api_surface=ApiSurface.inference,

View file

@@ -7,9 +7,6 @@
 from enum import Enum
 from typing import Literal, Optional, Union

-from hydra.core.config_store import ConfigStore
-from hydra_zen import builds
-
 from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat
 from pydantic import BaseModel, Field
@@ -19,13 +16,6 @@ from typing_extensions import Annotated
 from .datatypes import QuantizationConfig


-@json_schema_type
-class ImplType(Enum):
-    inline = "inline"
-    remote = "remote"
-    ollama = "ollama"
-
-
 @json_schema_type
 class CheckpointType(Enum):
     pytorch = "pytorch"
@@ -66,8 +56,8 @@ class ModelCheckpointConfig(BaseModel):


 @json_schema_type
-class InlineImplConfig(BaseModel):
-    impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
+class MetaReferenceImplConfig(BaseModel):
+    model: str
     checkpoint_config: ModelCheckpointConfig
     quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
@@ -75,28 +65,7 @@ class InlineImplConfig(BaseModel):
     max_batch_size: int = 1


-@json_schema_type
-class RemoteImplConfig(BaseModel):
-    impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
-    url: str = Field(..., description="The URL of the remote module")
-
-
 @json_schema_type
 class OllamaImplConfig(BaseModel):
-    impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
     model: str = Field(..., description="The name of the model in ollama catalog")
     url: str = Field(..., description="The URL for the ollama server")
-
-
-@json_schema_type
-class InferenceConfig(BaseModel):
-    impl_config: Annotated[
-        Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
-        Field(discriminator="impl_type"),
-    ]
-
-
-InferenceHydraConfig = builds(InferenceConfig)
-
-cs = ConfigStore.instance()
-cs.store(name="inference_config", node=InferenceHydraConfig)
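With ImplType and the Hydra plumbing gone, each implementation config is a standalone pydantic model selected through the adapter registry rather than a discriminated union, and the meta-reference config is keyed by a model name. For example, the Ollama config now carries only its own fields; the values below are illustrative:

from llama_toolchain.inference.api.config import OllamaImplConfig

cfg = OllamaImplConfig(
    model="llama3.1:8b-instruct-fp16",  # name of the model in the ollama catalog
    url="http://localhost:11434",       # where the ollama server is listening
)
print(cfg.model, cfg.url)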

View file

@@ -4,19 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .api.config import ImplType, InferenceConfig
+# from .api.config import ImplType, InferenceConfig


-async def get_inference_api_instance(config: InferenceConfig):
-    if config.impl_config.impl_type == ImplType.inline.value:
-        from .inference import InferenceImpl
+# async def get_inference_api_instance(config: InferenceConfig):
+#     if config.impl_config.impl_type == ImplType.inline.value:
+#         from .inference import InferenceImpl

-        return InferenceImpl(config.impl_config)
+#         return InferenceImpl(config.impl_config)

-    elif config.impl_config.impl_type == ImplType.ollama.value:
-        from .ollama import OllamaInference
+#     elif config.impl_config.impl_type == ImplType.ollama.value:
+#         from .ollama import OllamaInference

-        return OllamaInference(config.impl_config)
+#         return OllamaInference(config.impl_config)

-    from .client import InferenceClient
+#     from .client import InferenceClient

-    return InferenceClient(config.impl_config.url)
+#     return InferenceClient(config.impl_config.url)

View file

@@ -29,7 +29,7 @@ from llama_models.llama3_1.api.model import Transformer
 from llama_models.llama3_1.api.tokenizer import Tokenizer
 from termcolor import cprint

-from .api.config import CheckpointType, InlineImplConfig
+from .api.config import CheckpointType, MetaReferenceImplConfig
 from .api.datatypes import QuantizationType
@@ -42,7 +42,7 @@ class TokenResult:

 class Llama:
     @staticmethod
-    def build(config: InlineImplConfig):
+    def build(config: MetaReferenceImplConfig):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

View file

@@ -4,11 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+import asyncio
+from typing import AsyncIterator, Union

 from llama_models.llama3_1.api.datatypes import StopReason
+from llama_models.sku_list import resolve_model

-from .api.config import InlineImplConfig
+from .api.config import MetaReferenceImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -19,23 +22,35 @@ from .api.endpoints import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
-    CompletionRequest,
     Inference,
 )
 from .model_parallel import LlamaModelParallelGenerator


-def get_adapter_impl(config: InlineImplConfig) -> Inference:
+async def get_adapter_impl(config: MetaReferenceImplConfig):
     assert isinstance(
-        config, InlineImplConfig
+        config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"
-    return InferenceImpl(config)
+
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl


-class InferenceImpl(Inference):
+# there's a single model parallel process running serving the model. for now,
+# we don't support multiple concurrent requests to this process.
+SEMAPHORE = asyncio.Semaphore(1)

-    def __init__(self, config: InlineImplConfig) -> None:
+
+class MetaReferenceInferenceImpl(Inference):
+
+    def __init__(self, config: MetaReferenceImplConfig) -> None:
         self.config = config
+        model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model = model
+        # verify that the checkpoint actually is for this model lol

     async def initialize(self) -> None:
         self.generator = LlamaModelParallelGenerator(self.config)
@@ -44,125 +59,144 @@ class InferenceImpl(Inference):
     async def shutdown(self) -> None:
         self.generator.stop()

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
-        raise NotImplementedError()
-
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    # hm, when stream=False, we should not be doing SSE :/ which is what the
+    # top-level server is going to do. make the typing more specific here
+    async def chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncIterator[
+        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+    ]:
+        model = resolve_model(request.model)
+        if model is None:
+            raise RuntimeError(
+                f"Unknown model: {request.model}, Run `llama model list`"
+            )
+        elif model.descriptor() != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+            )
+
+        if SEMAPHORE.locked():
+            raise RuntimeError("Only one concurrent request is supported")
+
+        async with SEMAPHORE:
+            if request.stream:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.start,
+                        delta="",
+                    )
+                )
+
+            tokens = []
+            logprobs = []
+
+            stop_reason = None
+            buffer = ""
+            ipython = False
+
+            for token_result in self.generator.chat_completion(
+                messages=request.messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+            ):
+                buffer += token_result.text
+                tokens.append(token_result.token)
+
+                if not ipython and buffer.startswith("<|python_tag|>"):
+                    ipython = True
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
+                    buffer = buffer[len("<|python_tag|>") :]
+                    continue
+
+                if not request.stream:
+                    if request.logprobs:
+                        logprobs.append(token_result.logprob)
+                    continue
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                if ipython:
+                    delta = ToolCallDelta(
+                        content=text,
+                        parse_status=ToolCallParseStatus.in_progress,
+                    )
+                else:
+                    delta = text
+
+                if stop_reason is None:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=delta,
+                            stop_reason=stop_reason,
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            # TODO(ashwin): parse tool calls separately here and report errors?
+            # if someone breaks the iteration before coming here we are toast
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            if request.stream:
+                parsed_tool_calls = len(message.tool_calls) > 0
+                if ipython and not parsed_tool_calls:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.failure,
+                            ),
+                            stop_reason=stop_reason,
+                        )
+                    )

+                for tool_call in message.tool_calls:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content=tool_call,
+                                parse_status=ToolCallParseStatus.success,
+                            ),
+                            stop_reason=stop_reason,
+                        )
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.complete,
+                        delta="",
+                        stop_reason=stop_reason,
+                    )
+                )
+
+                # TODO(ashwin): what else do we need to send out here when everything finishes?
+            else:
+                yield ChatCompletionResponse(
+                    completion_message=message,
+                    logprobs=logprobs if request.logprobs else None,
+                )

View file

@@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.tokenizer import Tokenizer

-from .api.config import InlineImplConfig
+from .api.config import MetaReferenceImplConfig
 from .generation import Llama
 from .parallel_utils import ModelParallelProcessGroup
@@ -42,7 +42,7 @@ class ModelRunner:
         )


-def init_model_cb(config: InlineImplConfig):
+def init_model_cb(config: MetaReferenceImplConfig):
     llama = Llama.build(config)
     return ModelRunner(llama)
@@ -58,7 +58,7 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: InlineImplConfig):
+    def __init__(self, config: MetaReferenceImplConfig):
         self.config = config

         # this is a hack because Agent's loop uses this to tokenize and check if input is too long

View file

@@ -17,7 +17,7 @@ from llama_models.llama3_1.api.model import Transformer, TransformerBlock

 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,
-    InlineImplConfig,
+    MetaReferenceImplConfig,
 )
 from llama_toolchain.inference.api.datatypes import QuantizationType
@@ -46,7 +46,7 @@ def swiglu_wrapper(

 def convert_to_quantized_model(
     model: Transformer,
-    config: InlineImplConfig,
+    config: MetaReferenceImplConfig,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from .datatypes import *  # noqa: F403
-from typing import Protocol
+from typing import List, Protocol

 from llama_models.llama3_1.api.datatypes import Message
@@ -19,7 +19,7 @@ class RunShieldRequest(BaseModel):
     messages: List[Message]


-class SafetyCheck(Protocol):
+class Safety(Protocol):

     @webmethod(route="/safety/run_shield")
     async def run_shield(