From d3e269fcf2f0420fb210240b5ff51de05d12e92a Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 2 Aug 2024 14:18:25 -0700
Subject: [PATCH] Remove inference uvicorn server entrypoint and llama
 inference CLI command

---
 llama_toolchain/cli/inference/inference.py |   4 +-
 llama_toolchain/cli/inference/start.py     |  57 ----------
 llama_toolchain/inference/server.py        | 119 ---------------------
 3 files changed, 1 insertion(+), 179 deletions(-)
 delete mode 100644 llama_toolchain/cli/inference/start.py
 delete mode 100644 llama_toolchain/inference/server.py

diff --git a/llama_toolchain/cli/inference/inference.py b/llama_toolchain/cli/inference/inference.py
index c42771e6b..51a82b1f0 100644
--- a/llama_toolchain/cli/inference/inference.py
+++ b/llama_toolchain/cli/inference/inference.py
@@ -8,7 +8,6 @@ import argparse
 import textwrap
 
 from llama_toolchain.cli.inference.configure import InferenceConfigure
-from llama_toolchain.cli.inference.start import InferenceStart
 from llama_toolchain.cli.subcommand import Subcommand
 
 
@@ -31,6 +30,5 @@ class InferenceParser(Subcommand):
 
         subparsers = self.parser.add_subparsers(title="inference_subcommands")
 
-        # Add sub-commandsa
-        InferenceStart.create(subparsers)
+        # Add sub-commands
         InferenceConfigure.create(subparsers)
diff --git a/llama_toolchain/cli/inference/start.py b/llama_toolchain/cli/inference/start.py
deleted file mode 100644
index 820b9534c..000000000
--- a/llama_toolchain/cli/inference/start.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import argparse
-import textwrap
-
-from llama_toolchain.cli.subcommand import Subcommand
-
-from llama_toolchain.inference.server import main as inference_server_init
-
-
-class InferenceStart(Subcommand):
-    """Llama Inference cli for starting inference server"""
-
-    def __init__(self, subparsers: argparse._SubParsersAction):
-        super().__init__()
-        self.parser = subparsers.add_parser(
-            "start",
-            prog="llama inference start",
-            description="Start an inference server",
-            epilog=textwrap.dedent(
-                """
-                Example:
-                    llama inference start
-                """
-            ),
-            formatter_class=argparse.RawTextHelpFormatter,
-        )
-        self._add_arguments()
-        self.parser.set_defaults(func=self._run_inference_start_cmd)
-
-    def _add_arguments(self):
-        self.parser.add_argument(
-            "--port",
-            type=int,
-            help="Port to run the server on. Defaults to 5000",
-            default=5000,
-        )
-        self.parser.add_argument(
-            "--disable-ipv6",
-            action="store_true",
-            help="Disable IPv6 support",
-            default=False,
-        )
-        self.parser.add_argument(
-            "--config", type=str, help="Path to config file", default="inference"
-        )
-
-    def _run_inference_start_cmd(self, args: argparse.Namespace) -> None:
-        inference_server_init(
-            config_path=args.config,
-            port=args.port,
-            disable_ipv6=args.disable_ipv6,
-        )
diff --git a/llama_toolchain/inference/server.py b/llama_toolchain/inference/server.py
deleted file mode 100644
index c7adce267..000000000
--- a/llama_toolchain/inference/server.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import asyncio
-import signal
-
-import fire
-
-from dotenv import load_dotenv
-
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import StreamingResponse
-
-from hydra_zen import instantiate
-
-from llama_toolchain.utils import get_default_config_dir, parse_config
-from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk
-
-from .api_instance import get_inference_api_instance
-
-
-load_dotenv()
-
-
-GLOBAL_CONFIG = None
-
-
-def get_config():
-    return GLOBAL_CONFIG
-
-
-def handle_sigint(*args, **kwargs):
-    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
-    loop = asyncio.get_event_loop()
-    for task in asyncio.all_tasks(loop):
-        task.cancel()
-    loop.stop()
-
-
-app = FastAPI()
-
-
-@app.on_event("startup")
-async def startup():
-    global InferenceApiInstance
-
-    config = get_config()
-
-    inference_config = instantiate(config["inference_config"])
-    InferenceApiInstance = await get_inference_api_instance(
-        inference_config,
-    )
-    await InferenceApiInstance.initialize()
-
-
-@app.on_event("shutdown")
-async def shutdown():
-    global InferenceApiInstance
-
-    print("shutting down")
-    await InferenceApiInstance.shutdown()
-
-
-# there's a single model parallel process running serving the model. for now,
-# we don't support multiple concurrent requests to this process.
-semaphore = asyncio.Semaphore(1)
-
-
-@app.post(
-    "/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk
-)
-def chat_completion(request: Request, exec_request: ChatCompletionRequest):
-    if semaphore.locked():
-        raise HTTPException(
-            status_code=429,
-            detail="Only a single concurrent request allowed right now.",
-        )
-
-    async def sse_generator(event_gen):
-        try:
-            async for event in event_gen:
-                yield f"data: {event.json()}\n\n"
-                await asyncio.sleep(0.01)
-        except asyncio.CancelledError:
-            print("Generator cancelled")
-            await event_gen.aclose()
-        finally:
-            semaphore.release()
-
-    async def event_gen():
-        async for event in InferenceApiInstance.chat_completion(exec_request):
-            yield event
-
-    return StreamingResponse(
-        sse_generator(event_gen()),
-        media_type="text/event-stream",
-    )
-
-
-def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
-    global GLOBAL_CONFIG
-    config_dir = get_default_config_dir()
-    GLOBAL_CONFIG = parse_config(config_dir, config_path)
-
-    signal.signal(signal.SIGINT, handle_sigint)
-
-    import uvicorn
-
-    # FYI this does not do hot-reloads
-    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
-    print(f"Listening on {listen_host}:{port}")
-    uvicorn.run(app, host=listen_host, port=port)
-
-
-if __name__ == "__main__":
-    fire.Fire(main)