diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index 5a0ff4be9..10a6baf3c 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -18,11 +18,13 @@ from termcolor import cprint from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter -from llama_toolchain.distribution.registry import available_distributions +from llama_toolchain.distribution.registry import ( + available_distributions, + resolve_distribution, +) from llama_toolchain.utils import DISTRIBS_BASE_DIR -from .utils import run_command -DISTRIBS = available_distributions() +from .utils import run_command class DistributionConfigure(Subcommand): @@ -43,18 +45,13 @@ class DistributionConfigure(Subcommand): self.parser.add_argument( "--name", type=str, - help="Mame of the distribution to configure", + help="Name of the distribution to configure", default="local-source", choices=[d.name for d in available_distributions()], ) def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None: - dist = None - for d in DISTRIBS: - if d.name == args.name: - dist = d - break - + dist = resolve_distribution(args.name) if dist is None: self.parser.error(f"Could not find distribution {args.name}") return diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py index 98d3d47dd..d169e46cc 100644 --- a/llama_toolchain/cli/distribution/create.py +++ b/llama_toolchain/cli/distribution/create.py @@ -7,6 +7,7 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand +from llama_toolchain.distribution.registry import resolve_distribution class DistributionCreate(Subcommand): @@ -23,7 +24,20 @@ class DistributionCreate(Subcommand): self.parser.set_defaults(func=self._run_distribution_create_cmd) def _add_arguments(self): - pass + self.parser.add_argument( + "--name", + type=str, + help="Name of the distribution to create", + required=True, + ) + # for each ApiSurface the user wants to support, we should + # get the list of available adapters, ask which one the user + # wants to pick and then ask for their configuration. 
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None: + dist = resolve_distribution(args.name) + if dist is not None: + self.parser.error(f"Distribution with name {args.name} already exists") + return + raise NotImplementedError() diff --git a/llama_toolchain/cli/distribution/distribution.py b/llama_toolchain/cli/distribution/distribution.py index c553dcf3b..afc5f9341 100644 --- a/llama_toolchain/cli/distribution/distribution.py +++ b/llama_toolchain/cli/distribution/distribution.py @@ -12,6 +12,7 @@ from .configure import DistributionConfigure from .create import DistributionCreate from .install import DistributionInstall from .list import DistributionList +from .start import DistributionStart class DistributionParser(Subcommand): @@ -31,3 +32,4 @@ class DistributionParser(Subcommand): DistributionInstall.create(subparsers) DistributionCreate.create(subparsers) DistributionConfigure.create(subparsers) + DistributionStart.create(subparsers) diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py index df60e7ad1..d45456f75 100644 --- a/llama_toolchain/cli/distribution/install.py +++ b/llama_toolchain/cli/distribution/install.py @@ -11,9 +11,13 @@ import shlex import pkg_resources from llama_toolchain.cli.subcommand import Subcommand -from llama_toolchain.distribution.datatypes import distribution_dependencies -from llama_toolchain.distribution.registry import available_distributions +from llama_toolchain.distribution.distribution import distribution_dependencies +from llama_toolchain.distribution.registry import ( + available_distributions, + resolve_distribution, +) from llama_toolchain.utils import DISTRIBS_BASE_DIR + from .utils import run_command, run_with_pty DISTRIBS = available_distributions() @@ -55,12 +59,7 @@ class DistributionInstall(Subcommand): "distribution/install_distribution.sh", ) - dist = None - for d in DISTRIBS: - if d.name == args.name: - dist = d - break - + dist = resolve_distribution(args.name) if dist is None: self.parser.error(f"Could not find distribution {args.name}") return diff --git a/llama_toolchain/cli/distribution/list.py b/llama_toolchain/cli/distribution/list.py index 39e93d8ec..d20980432 100644 --- a/llama_toolchain/cli/distribution/list.py +++ b/llama_toolchain/cli/distribution/list.py @@ -9,7 +9,7 @@ import argparse from llama_toolchain.cli.subcommand import Subcommand from llama_toolchain.cli.table import print_table -from llama_toolchain.distribution.datatypes import distribution_dependencies +from llama_toolchain.distribution.distribution import distribution_dependencies from llama_toolchain.distribution.registry import available_distributions diff --git a/llama_toolchain/cli/distribution/start.py b/llama_toolchain/cli/distribution/start.py new file mode 100644 index 000000000..b567726db --- /dev/null +++ b/llama_toolchain/cli/distribution/start.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse +import shlex +from pathlib import Path + +import yaml + +from llama_toolchain.cli.subcommand import Subcommand +from llama_toolchain.distribution.registry import resolve_distribution +from llama_toolchain.distribution.server import main as distribution_server_init +from llama_toolchain.utils import DISTRIBS_BASE_DIR + +from .utils import run_command + + +class DistributionStart(Subcommand): + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "start", + prog="llama distribution start", + description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_distribution_start_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "--name", + type=str, + help="Name of the distribution to start", + required=True, + ) + self.parser.add_argument( + "--port", + type=int, + help="Port to run the server on. Defaults to 5000", + default=5000, + ) + self.parser.add_argument( + "--disable-ipv6", + action="store_true", + help="Disable IPv6 support", + default=False, + ) + + def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None: + dist = resolve_distribution(args.name) + if dist is None: + self.parser.error(f"Distribution with name {args.name} not found") + return + + config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml" + if not config_yaml.exists(): + raise ValueError( + f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first" + ) + + with open(config_yaml, "r") as fp: + config = yaml.safe_load(fp) + + conda_env = config["conda_env"] + python_exe = run_command(shlex.split("which python")) + # simple check, unfortunate + if conda_env not in python_exe: + raise ValueError( + f"Please re-run start after activating the `{conda_env}` conda environment first" + ) + + distribution_server_init( + dist.name, + config_yaml, + args.port, + disable_ipv6=args.disable_ipv6, + ) + # run_with_pty( + # shlex.split( + # f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000" + # ) + # ) diff --git a/llama_toolchain/cli/distribution/utils.py b/llama_toolchain/cli/distribution/utils.py index 94ed1b0bb..91547ea83 100644 --- a/llama_toolchain/cli/distribution/utils.py +++ b/llama_toolchain/cli/distribution/utils.py @@ -8,21 +8,35 @@ import errno import os import pty import select +import signal import subprocess import sys import termios -import tty + +from termcolor import cprint def run_with_pty(command): - old_settings = termios.tcgetattr(sys.stdin) - - # Create a new pseudo-terminal master, slave = pty.openpty() + old_settings = termios.tcgetattr(sys.stdin) + original_sigint = signal.getsignal(signal.SIGINT) + + ctrl_c_pressed = False + + def sigint_handler(signum, frame): + nonlocal ctrl_c_pressed + ctrl_c_pressed = True + cprint("\nCtrl-C detected. 
Aborting...", "white", attrs=["bold"]) + try: - # ensure the terminal does not echo input - tty.setraw(sys.stdin.fileno()) + # Set up the signal handler + signal.signal(signal.SIGINT, sigint_handler) + + new_settings = termios.tcgetattr(sys.stdin) + new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo + new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings) process = subprocess.Popen( command, @@ -30,26 +44,36 @@ def run_with_pty(command): stdout=slave, stderr=slave, universal_newlines=True, + preexec_fn=os.setsid, ) # Close the slave file descriptor as it's now owned by the subprocess os.close(slave) def handle_io(): - while True: - rlist, _, _ = select.select([sys.stdin, master], [], []) + while not ctrl_c_pressed: + try: + rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1) - if sys.stdin in rlist: - data = os.read(sys.stdin.fileno(), 1024) - if not data: # EOF - break - os.write(master, data) + if sys.stdin in rlist: + data = os.read(sys.stdin.fileno(), 1024) + if not data: + break + os.write(master, data) - if master in rlist: - data = os.read(master, 1024) - if not data: - break - os.write(sys.stdout.fileno(), data) + if master in rlist: + data = os.read(master, 1024) + if not data: + break + sys.stdout.buffer.write(data) + sys.stdout.flush() + + except KeyboardInterrupt: + # This will be raised when Ctrl+C is pressed + break + + if process.poll() is not None: + break handle_io() except (EOFError, KeyboardInterrupt): @@ -58,11 +82,14 @@ def run_with_pty(command): if e.errno != errno.EIO: raise finally: - # Restore original terminal settings + # Clean up termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) + signal.signal(signal.SIGINT, original_sigint) - process.wait() - os.close(master) + os.close(master) + if process.poll() is None: + process.terminate() + process.wait() return process.returncode diff --git a/llama_toolchain/cli/llama.py b/llama_toolchain/cli/llama.py index f77855277..4764cf32e 100644 --- a/llama_toolchain/cli/llama.py +++ b/llama_toolchain/cli/llama.py @@ -8,7 +8,8 @@ import argparse from .distribution import DistributionParser from .download import Download -from .inference import InferenceParser + +# from .inference import InferenceParser from .model import ModelParser @@ -29,7 +30,7 @@ class LlamaCLIParser: # Add sub-commands Download.create(subparsers) - InferenceParser.create(subparsers) + # InferenceParser.create(subparsers) ModelParser.create(subparsers) DistributionParser.create(subparsers) diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py index 4209e6d26..82196a00b 100644 --- a/llama_toolchain/distribution/datatypes.py +++ b/llama_toolchain/distribution/datatypes.py @@ -5,54 +5,10 @@ # the root directory of this source tree. 
from enum import Enum -from typing import Any, Dict, List, Literal, Union +from typing import Dict, List from pydantic import BaseModel, Field from strong_typing.schema import json_schema_type -from typing_extensions import Annotated - - -@json_schema_type -class AdapterType(Enum): - passthrough_api = "passthrough_api" - python_impl = "python_impl" - not_implemented = "not_implemented" - - -@json_schema_type -class PassthroughApiAdapterConfig(BaseModel): - type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value - base_url: str = Field(..., description="The base URL for the llama stack provider") - headers: Dict[str, str] = Field( - default_factory=dict, - description="Headers (e.g., authorization) to send with the request", - ) - - -@json_schema_type -class PythonImplAdapterConfig(BaseModel): - type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value - adapter_id: str - kwargs: Dict[str, Any] = Field( - default_factory=dict, description="kwargs to pass to the entrypoint" - ) - - -@json_schema_type -class NotImplementedAdapterConfig(BaseModel): - type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value - - -# should we define very granular typed classes for each of the PythonImplAdapters we will have? -# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters -AdapterConfig = Annotated[ - Union[ - PassthroughApiAdapterConfig, - NotImplementedAdapterConfig, - PythonImplAdapterConfig, - ], - Field(discriminator="type"), -] @json_schema_type @@ -61,6 +17,13 @@ class ApiSurface(Enum): safety = "safety" +@json_schema_type +class ApiSurfaceEndpoint(BaseModel): + route: str + method: str + name: str + + @json_schema_type class Adapter(BaseModel): api_surface: ApiSurface @@ -108,13 +71,3 @@ class Distribution(BaseModel): default_factory=list, description="Additional pip packages beyond those required by the adapters", ) - - -def distribution_dependencies(distribution: Distribution) -> List[str]: - # only consider SourceAdapters when calculating dependencies - return [ - dep - for adapter in distribution.adapters.values() - if isinstance(adapter, SourceAdapter) - for dep in adapter.pip_packages - ] + distribution.additional_pip_packages diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py new file mode 100644 index 000000000..773d05f26 --- /dev/null +++ b/llama_toolchain/distribution/distribution.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import inspect +from typing import Dict, List + +from llama_toolchain.inference.api.endpoints import Inference +from llama_toolchain.safety.api.endpoints import Safety + +from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter + + +def distribution_dependencies(distribution: Distribution) -> List[str]: + # only consider SourceAdapters when calculating dependencies + return [ + dep + for adapter in distribution.adapters.values() + if isinstance(adapter, SourceAdapter) + for dep in adapter.pip_packages + ] + distribution.additional_pip_packages + + +def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]: + surfaces = {} + + protocols = { + ApiSurface.inference: Inference, + ApiSurface.safety: Safety, + } + + for surface, protocol in protocols.items(): + endpoints = [] + protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) + + for name, method in protocol_methods: + if not hasattr(method, "__webmethod__"): + continue + + webmethod = method.__webmethod__ + route = webmethod.route + + # use `post` for all methods right now until we fix up the `webmethod` openapi + # annotation and write our own openapi generator + endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name)) + + surfaces[surface] = endpoints + + return surfaces diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py new file mode 100644 index 000000000..483c08d79 --- /dev/null +++ b/llama_toolchain/distribution/dynamic.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import importlib +from typing import Any, Dict + +from .datatypes import SourceAdapter + + +def instantiate_class_type(fully_qualified_name): + module_name, class_name = fully_qualified_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + + +# returns a class implementing the protocol corresponding to the ApiSurface +def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]): + module = importlib.import_module(adapter.module) + + config_type = instantiate_class_type(adapter.config_class) + config = config_type(**adapter_config) + return asyncio.run(module.get_adapter_impl(config)) diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 1dec45861..ceb101cd4 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List +from typing import List, Optional from llama_toolchain.inference.adapters import available_inference_adapters @@ -63,3 +63,10 @@ def available_distributions() -> List[Distribution]: }, ), ] + + +def resolve_distribution(name: str) -> Optional[Distribution]: + for dist in available_distributions(): + if dist.name == name: + return dist + return None diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py new file mode 100644 index 000000000..d45e3b041 --- /dev/null +++ b/llama_toolchain/distribution/server.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json +import signal +from collections.abc import ( + AsyncGenerator as AsyncGeneratorABC, + AsyncIterator as AsyncIteratorABC, +) +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional + +import fire +import httpx +import yaml +from dotenv import load_dotenv + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +from fastapi.routing import APIRoute + +from pydantic import BaseModel +from termcolor import cprint + +from .datatypes import PassthroughApiAdapter +from .distribution import api_surface_endpoints +from .dynamic import instantiate_adapter + +from .registry import resolve_distribution + +load_dotenv() + + +def is_async_iterator_type(typ): + if hasattr(typ, "__origin__"): + origin = typ.__origin__ + if isinstance(origin, type): + return issubclass( + origin, + (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC), + ) + return False + return isinstance( + typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC) + ) + + +def create_sse_event(data: Any) -> str: + if isinstance(data, BaseModel): + data = data.json() + else: + data = json.dumps(data) + + return f"data: {data}\n\n" + + +async def passthrough( + request: Request, + downstream_url: str, + downstream_headers: Optional[Dict[str, str]] = None, +): + client = httpx.AsyncClient() + + headers = dict(request.headers) + headers.pop("host", None) + headers.update(downstream_headers or {}) + + body = await request.body() + + try: + response = await client.request( + method=request.method, + url=downstream_url, + headers=headers, + content=body, + params=request.query_params, + ) + return StreamingResponse( + response.iter_bytes(), + status_code=response.status_code, + headers=dict(response.headers), + ) + finally: + await client.aclose() + + +def handle_sigint(*args, **kwargs): + print("SIGINT or CTRL-C detected. 
Exiting gracefully", args) + loop = asyncio.get_event_loop() + for task in asyncio.all_tasks(loop): + task.cancel() + loop.stop() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + print("Starting up") + yield + print("Shutting down") + + +def create_dynamic_passthrough( + downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None +): + async def endpoint(request: Request): + return await passthrough(request, downstream_url, downstream_headers) + + return endpoint + + +def create_dynamic_typed_route(func: Any): + hints = get_type_hints(func) + request_model = next(iter(hints.values())) + response_model = hints["return"] + + is_streaming = is_async_iterator_type(response_model) + + if is_streaming: + + async def endpoint(request: request_model): + async def event_generator(): + async for item in func(request): + yield create_sse_event(item) + await asyncio.sleep(0.001) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + else: + + async def endpoint(request: request_model): + return func(request) + + return endpoint + + +def main( + dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False +): + dist = resolve_distribution(dist_name) + if dist is None: + raise ValueError(f"Could not find distribution {dist_name}") + + with open(yaml_config, "r") as fp: + config = yaml.safe_load(fp) + + app = FastAPI() + + all_endpoints = api_surface_endpoints() + + adapter_configs = config["adapters"] + for surface, adapter in dist.adapters.items(): + if surface.value not in adapter_configs: + raise ValueError( + f"Could not find adapter config for {surface}. Please add it to the config" + ) + + adapter_config = adapter_configs[surface.value] + endpoints = all_endpoints[surface] + if isinstance(adapter, PassthroughApiAdapter): + for endpoint in endpoints: + url = adapter.base_url.rstrip("/") + endpoint.route + getattr(app, endpoint.method)(endpoint.route)( + create_dynamic_passthrough(url) + ) + else: + impl = instantiate_adapter(adapter, adapter_config) + for endpoint in endpoints: + if not hasattr(impl, endpoint.name): + # ideally this should be a typing violation already + raise ValueError( + f"Could not find method {endpoint.name} on {impl}!!" 
+ ) + + impl_method = getattr(impl, endpoint.name) + getattr(app, endpoint.method)(endpoint.route, response_model=None)( + create_dynamic_typed_route(impl_method) + ) + + for route in app.routes: + if isinstance(route, APIRoute): + cprint( + f"Serving {next(iter(route.methods))} {route.path}", + "white", + attrs=["bold"], + ) + + signal.signal(signal.SIGINT, handle_sigint) + + import uvicorn + + # FYI this does not do hot-reloads + listen_host = "::" if not disable_ipv6 else "0.0.0.0" + print(f"Listening on {listen_host}:{port}") + uvicorn.run(app, host=listen_host, port=port) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_toolchain/inference/adapters.py b/llama_toolchain/inference/adapters.py index 5b1b5c873..4ab087221 100644 --- a/llama_toolchain/inference/adapters.py +++ b/llama_toolchain/inference/adapters.py @@ -19,7 +19,7 @@ def available_inference_adapters() -> List[Adapter]: "zmq", ], module="llama_toolchain.inference.inference", - config_class="llama_toolchain.inference.inference.InlineImplConfig", + config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig", ), SourceAdapter( api_surface=ApiSurface.inference, diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py index 6bac2d09d..2f01f90db 100644 --- a/llama_toolchain/inference/api/config.py +++ b/llama_toolchain/inference/api/config.py @@ -7,9 +7,6 @@ from enum import Enum from typing import Literal, Optional, Union -from hydra.core.config_store import ConfigStore - -from hydra_zen import builds from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat from pydantic import BaseModel, Field @@ -19,13 +16,6 @@ from typing_extensions import Annotated from .datatypes import QuantizationConfig -@json_schema_type -class ImplType(Enum): - inline = "inline" - remote = "remote" - ollama = "ollama" - - @json_schema_type class CheckpointType(Enum): pytorch = "pytorch" @@ -66,8 +56,8 @@ class ModelCheckpointConfig(BaseModel): @json_schema_type -class InlineImplConfig(BaseModel): - impl_type: Literal[ImplType.inline.value] = ImplType.inline.value +class MetaReferenceImplConfig(BaseModel): + model: str checkpoint_config: ModelCheckpointConfig quantization: Optional[QuantizationConfig] = None torch_seed: Optional[int] = None @@ -75,28 +65,7 @@ class InlineImplConfig(BaseModel): max_batch_size: int = 1 -@json_schema_type -class RemoteImplConfig(BaseModel): - impl_type: Literal[ImplType.remote.value] = ImplType.remote.value - url: str = Field(..., description="The URL of the remote module") - - @json_schema_type class OllamaImplConfig(BaseModel): - impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value model: str = Field(..., description="The name of the model in ollama catalog") url: str = Field(..., description="The URL for the ollama server") - - -@json_schema_type -class InferenceConfig(BaseModel): - impl_config: Annotated[ - Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig], - Field(discriminator="impl_type"), - ] - - -InferenceHydraConfig = builds(InferenceConfig) - -cs = ConfigStore.instance() -cs.store(name="inference_config", node=InferenceHydraConfig) diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py index 975de3446..560b99868 100644 --- a/llama_toolchain/inference/api_instance.py +++ b/llama_toolchain/inference/api_instance.py @@ -4,19 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from .api.config import ImplType, InferenceConfig +# from .api.config import ImplType, InferenceConfig -async def get_inference_api_instance(config: InferenceConfig): - if config.impl_config.impl_type == ImplType.inline.value: - from .inference import InferenceImpl +# async def get_inference_api_instance(config: InferenceConfig): +# if config.impl_config.impl_type == ImplType.inline.value: +# from .inference import InferenceImpl - return InferenceImpl(config.impl_config) - elif config.impl_config.impl_type == ImplType.ollama.value: - from .ollama import OllamaInference +# return InferenceImpl(config.impl_config) +# elif config.impl_config.impl_type == ImplType.ollama.value: +# from .ollama import OllamaInference - return OllamaInference(config.impl_config) +# return OllamaInference(config.impl_config) - from .client import InferenceClient +# from .client import InferenceClient - return InferenceClient(config.impl_config.url) +# return InferenceClient(config.impl_config.url) diff --git a/llama_toolchain/inference/generation.py b/llama_toolchain/inference/generation.py index a6239edf5..be3b6967d 100644 --- a/llama_toolchain/inference/generation.py +++ b/llama_toolchain/inference/generation.py @@ -29,7 +29,7 @@ from llama_models.llama3_1.api.model import Transformer from llama_models.llama3_1.api.tokenizer import Tokenizer from termcolor import cprint -from .api.config import CheckpointType, InlineImplConfig +from .api.config import CheckpointType, MetaReferenceImplConfig from .api.datatypes import QuantizationType @@ -42,7 +42,7 @@ class TokenResult: class Llama: @staticmethod - def build(config: InlineImplConfig): + def build(config: MetaReferenceImplConfig): """ Build a Llama instance by initializing and loading a model checkpoint. diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index 61742a509..beeb6dd65 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -4,11 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +import asyncio + +from typing import AsyncIterator, Union from llama_models.llama3_1.api.datatypes import StopReason +from llama_models.sku_list import resolve_model -from .api.config import InlineImplConfig +from .api.config import MetaReferenceImplConfig from .api.datatypes import ( ChatCompletionResponseEvent, ChatCompletionResponseEventType, @@ -19,23 +22,35 @@ from .api.endpoints import ( ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseStreamChunk, - CompletionRequest, Inference, ) from .model_parallel import LlamaModelParallelGenerator -def get_adapter_impl(config: InlineImplConfig) -> Inference: +async def get_adapter_impl(config: MetaReferenceImplConfig): assert isinstance( - config, InlineImplConfig + config, MetaReferenceImplConfig ), f"Unexpected config type: {type(config)}" - return InferenceImpl(config) + + impl = MetaReferenceInferenceImpl(config) + await impl.initialize() + return impl -class InferenceImpl(Inference): +# there's a single model parallel process running serving the model. for now, +# we don't support multiple concurrent requests to this process. 
+SEMAPHORE = asyncio.Semaphore(1) - def __init__(self, config: InlineImplConfig) -> None: + +class MetaReferenceInferenceImpl(Inference): + + def __init__(self, config: MetaReferenceImplConfig) -> None: self.config = config + model = resolve_model(config.model) + if model is None: + raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") + self.model = model + # verify that the checkpoint actually is for this model lol async def initialize(self) -> None: self.generator = LlamaModelParallelGenerator(self.config) @@ -44,125 +59,144 @@ class InferenceImpl(Inference): async def shutdown(self) -> None: self.generator.stop() - async def completion(self, request: CompletionRequest) -> AsyncGenerator: - raise NotImplementedError() - - async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - if request.stream: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) + # hm, when stream=False, we should not be doing SSE :/ which is what the + # top-level server is going to do. make the typing more specific here + async def chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncIterator[ + Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] + ]: + model = resolve_model(request.model) + if model is None: + raise RuntimeError( + f"Unknown model: {request.model}, Run `llama model list`" + ) + elif model.descriptor() != self.model.descriptor(): + raise RuntimeError( + f"Model mismatch: {request.model} != {self.model.descriptor()}" ) - tokens = [] - logprobs = [] + if SEMAPHORE.locked(): + raise RuntimeError("Only one concurrent request is supported") - stop_reason = None - - buffer = "" - ipython = False - - for token_result in self.generator.chat_completion( - messages=request.messages, - temperature=request.sampling_params.temperature, - top_p=request.sampling_params.top_p, - max_gen_len=request.sampling_params.max_tokens, - logprobs=request.logprobs, - ): - buffer += token_result.text - tokens.append(token_result.token) - - if not ipython and buffer.startswith("<|python_tag|>"): - ipython = True + async with SEMAPHORE: + if request.stream: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.started, - ), + event_type=ChatCompletionResponseEventType.start, + delta="", ) ) - buffer = buffer[len("<|python_tag|>") :] - continue - if not request.stream: - if request.logprobs: - logprobs.append(token_result.logprob) + tokens = [] + logprobs = [] - continue + stop_reason = None - if token_result.text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - elif token_result.text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - else: - text = token_result.text + buffer = "" + ipython = False - if ipython: - delta = ToolCallDelta( - content=text, - parse_status=ToolCallParseStatus.in_progress, - ) - else: - delta = text + for token_result in self.generator.chat_completion( + messages=request.messages, + temperature=request.sampling_params.temperature, + top_p=request.sampling_params.top_p, + max_gen_len=request.sampling_params.max_tokens, + logprobs=request.logprobs, + ): + buffer += token_result.text + tokens.append(token_result.token) + + if not ipython and buffer.startswith("<|python_tag|>"): + ipython = True + yield ChatCompletionResponseStreamChunk( + 
event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.started, + ), + ) + ) + buffer = buffer[len("<|python_tag|>") :] + continue + + if not request.stream: + if request.logprobs: + logprobs.append(token_result.logprob) + + continue + + if token_result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + else: + text = token_result.text + + if ipython: + delta = ToolCallDelta( + content=text, + parse_status=ToolCallParseStatus.in_progress, + ) + else: + delta = text + + if stop_reason is None: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) + ) if stop_reason is None: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) + stop_reason = StopReason.out_of_tokens - if stop_reason is None: - stop_reason = StopReason.out_of_tokens - - # TODO(ashwin): parse tool calls separately here and report errors? - # if someone breaks the iteration before coming here we are toast - message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) - if request.stream: - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.failure, - ), - stop_reason=stop_reason, - ) - ) - - for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=tool_call, - parse_status=ToolCallParseStatus.success, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) + # TODO(ashwin): parse tool calls separately here and report errors? + # if someone breaks the iteration before coming here we are toast + message = self.generator.formatter.decode_assistant_message( + tokens, stop_reason ) + if request.stream: + parsed_tool_calls = len(message.tool_calls) > 0 + if ipython and not parsed_tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.failure, + ), + stop_reason=stop_reason, + ) + ) - # TODO(ashwin): what else do we need to send out here when everything finishes? 
- else: - yield ChatCompletionResponse( - completion_message=message, - logprobs=logprobs if request.logprobs else None, - ) + for tool_call in message.tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=tool_call, + parse_status=ToolCallParseStatus.success, + ), + stop_reason=stop_reason, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) + + # TODO(ashwin): what else do we need to send out here when everything finishes? + else: + yield ChatCompletionResponse( + completion_message=message, + logprobs=logprobs if request.logprobs else None, + ) diff --git a/llama_toolchain/inference/model_parallel.py b/llama_toolchain/inference/model_parallel.py index a5066df82..8426f7890 100644 --- a/llama_toolchain/inference/model_parallel.py +++ b/llama_toolchain/inference/model_parallel.py @@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat from llama_models.llama3_1.api.datatypes import Message from llama_models.llama3_1.api.tokenizer import Tokenizer -from .api.config import InlineImplConfig +from .api.config import MetaReferenceImplConfig from .generation import Llama from .parallel_utils import ModelParallelProcessGroup @@ -42,7 +42,7 @@ class ModelRunner: ) -def init_model_cb(config: InlineImplConfig): +def init_model_cb(config: MetaReferenceImplConfig): llama = Llama.build(config) return ModelRunner(llama) @@ -58,7 +58,7 @@ class LlamaModelParallelGenerator: clear at the callsite why we need to use a context manager. """ - def __init__(self, config: InlineImplConfig): + def __init__(self, config: MetaReferenceImplConfig): self.config = config # this is a hack because Agent's loop uses this to tokenize and check if input is too long diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_toolchain/inference/quantization/loader.py index 195ba1e96..583123df6 100644 --- a/llama_toolchain/inference/quantization/loader.py +++ b/llama_toolchain/inference/quantization/loader.py @@ -17,7 +17,7 @@ from llama_models.llama3_1.api.model import Transformer, TransformerBlock from llama_toolchain.inference.api.config import ( CheckpointQuantizationFormat, - InlineImplConfig, + MetaReferenceImplConfig, ) from llama_toolchain.inference.api.datatypes import QuantizationType @@ -46,7 +46,7 @@ def swiglu_wrapper( def convert_to_quantized_model( model: Transformer, - config: InlineImplConfig, + config: MetaReferenceImplConfig, fp8_activation_scale_ub: Optional[float] = 1200.0, ) -> Transformer: if config.quantization.type == QuantizationType.bf16.value: diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py index 8558ed8fd..0f50abae3 100644 --- a/llama_toolchain/safety/api/endpoints.py +++ b/llama_toolchain/safety/api/endpoints.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from .datatypes import * # noqa: F403 -from typing import Protocol +from typing import List, Protocol from llama_models.llama3_1.api.datatypes import Message @@ -19,7 +19,7 @@ class RunShieldRequest(BaseModel): messages: List[Message] -class SafetyCheck(Protocol): +class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield(
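
Usage note: the new `llama distribution start --name <name> --port <port>` command added above is a thin wrapper over `llama_toolchain.distribution.server.main`: it resolves the distribution, loads the `config.yaml` written by `llama distribution install`/`configure`, checks that the matching conda environment is active, and then boots the FastAPI server. A rough sketch of the same flow driven directly from Python follows; "local-source" is only an example name, error handling is reduced to asserts, and the conda-env check is omitted.

# Sketch only: mirrors _run_distribution_start_cmd in start.py under the
# assumption that $DISTRIBS_BASE_DIR/<name>/config.yaml already exists.
from pathlib import Path

import yaml

from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.server import main as serve
from llama_toolchain.utils import DISTRIBS_BASE_DIR

name = "local-source"  # example distribution name
dist = resolve_distribution(name)
assert dist is not None, f"unknown distribution {name}"

config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
assert config_yaml.exists(), f"{config_yaml} missing; run install/configure first"

# config.yaml is expected to carry `conda_env` plus an `adapters` mapping keyed
# by API surface ("inference", "safety") holding each adapter's config
config = yaml.safe_load(config_yaml.read_text())
print(f"starting {dist.name} with adapters: {sorted(config['adapters'])}")

# blocks and serves the distribution's API surfaces via uvicorn
serve(dist.name, str(config_yaml), port=5000, disable_ipv6=False)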
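
The server wires its routes by introspection: `api_surface_endpoints` enumerates protocol methods tagged with a `__webmethod__` attribute, and `server.main` registers each one on the FastAPI app. A minimal standalone sketch of that pattern is below; the `webmethod` decorator, `Ping` protocol, `PingRequest` model, and `PingImpl` class are illustrative stand-ins, not toolchain APIs.

# Sketch of the introspection-driven route wiring used in server.py; all names
# defined here are hypothetical stand-ins for the real protocols and adapters.
import inspect
from types import SimpleNamespace
from typing import Protocol

from fastapi import FastAPI
from pydantic import BaseModel


def webmethod(route: str):
    # attach routing metadata to a protocol method, as the real decorator does
    def wrap(fn):
        fn.__webmethod__ = SimpleNamespace(route=route)
        return fn

    return wrap


class PingRequest(BaseModel):
    message: str


class Ping(Protocol):
    @webmethod(route="/ping")
    async def ping(self, request: PingRequest) -> dict: ...


class PingImpl:
    async def ping(self, request: PingRequest) -> dict:
        return {"echo": request.message}


app = FastAPI()
impl = PingImpl()

# find protocol methods carrying __webmethod__ and register POST routes for the
# corresponding bound methods on the implementation, as server.main does
for name, method in inspect.getmembers(Ping, predicate=inspect.isfunction):
    if hasattr(method, "__webmethod__"):
        app.post(method.__webmethod__.route)(getattr(impl, name))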
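
For streaming methods (return type is an async iterator), `create_dynamic_typed_route` wraps the adapter call in a `StreamingResponse` that emits `data: <json>\n\n` frames built by `create_sse_event`. A self-contained sketch of that server-sent-events shape follows; the `ChatChunk` model and `chat` generator are hypothetical stand-ins for an adapter's streaming method.

# Standalone sketch of the SSE framing used by the distribution server.
import asyncio
import json
from typing import AsyncIterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()


class ChatChunk(BaseModel):
    delta: str


async def chat(prompt: str) -> AsyncIterator[ChatChunk]:
    # stand-in for an adapter's streaming chat_completion implementation
    for word in prompt.split():
        yield ChatChunk(delta=word)
        await asyncio.sleep(0.001)


def sse_event(item) -> str:
    # mirrors create_sse_event: pydantic models are serialized to JSON
    data = item.json() if isinstance(item, BaseModel) else json.dumps(item)
    return f"data: {data}\n\n"


@app.post("/chat")
async def chat_endpoint(prompt: str):
    async def events():
        async for chunk in chat(prompt):
            yield sse_event(chunk)

    return StreamingResponse(events(), media_type="text/event-stream")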