fix: Use lifespan shutdown to handle shutdown

uvicorn will already handle SIGINT and SIGTERM by yielding in lifespan()
handler, which allows to proceed to shutdown handler.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-10 19:57:13 +00:00
parent 961c87f2c7
commit e4acdf6d54

View file

@ -6,11 +6,9 @@
import argparse import argparse
import asyncio import asyncio
import functools
import inspect import inspect
import json import json
import os import os
import signal
import sys import sys
import traceback import traceback
import warnings import warnings
@ -115,35 +113,12 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
) )
def handle_signal(app, signum, _) -> None: async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
""" """
Handle incoming signals and initiate a graceful shutdown of the application.
This function is intended to be used as a signal handler for various signals
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values(): for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__ impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name) logger.info("Shutting down %s", impl_name)
@ -157,36 +132,13 @@ def handle_signal(app, signum, _) -> None:
except (Exception, asyncio.CancelledError) as e: except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e}) logger.exception("Failed to shutdown %s: %s", impl_name, {e})
loop = asyncio.get_running_loop()
try:
# Gather all running tasks
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Starting up") logger.info("Starting up")
yield yield
logger.info("Shutting down") logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values(): await shutdown(app)
await impl.shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs): def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -436,8 +388,6 @@ def main():
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls