mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix: Use lifespan shutdown to handle shutdown
uvicorn already handles SIGINT and SIGTERM: on either signal it lets the lifespan() handler resume after its yield, which allows execution to proceed to the shutdown handler. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
961c87f2c7
commit
e4acdf6d54
1 changed files with 17 additions and 67 deletions
|
@ -6,11 +6,9 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -115,35 +113,12 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_signal(app, signum, _) -> None:
|
async def shutdown(app):
|
||||||
|
"""Initiate a graceful shutdown of the application.
|
||||||
|
|
||||||
|
Handled by the lifespan context manager. The shutdown process involves
|
||||||
|
shutting down all implementations registered in the application.
|
||||||
"""
|
"""
|
||||||
Handle incoming signals and initiate a graceful shutdown of the application.
|
|
||||||
|
|
||||||
This function is intended to be used as a signal handler for various signals
|
|
||||||
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
|
|
||||||
indicating the received signal and initiate a shutdown process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app: The application instance containing implementations to be shut down.
|
|
||||||
signum (int): The signal number received.
|
|
||||||
frame: The current stack frame (not used in this function).
|
|
||||||
|
|
||||||
The shutdown process involves:
|
|
||||||
- Shutting down all implementations registered in the application.
|
|
||||||
- Gathering all running asyncio tasks.
|
|
||||||
- Cancelling all gathered tasks.
|
|
||||||
- Waiting for all tasks to finish.
|
|
||||||
- Stopping the event loop.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
This function schedules the shutdown process as an asyncio task and does
|
|
||||||
not block the current execution.
|
|
||||||
"""
|
|
||||||
signame = signal.Signals(signum).name
|
|
||||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
|
||||||
|
|
||||||
async def shutdown():
|
|
||||||
# Gracefully shut down implementations
|
|
||||||
for impl in app.__llama_stack_impls__.values():
|
for impl in app.__llama_stack_impls__.values():
|
||||||
impl_name = impl.__class__.__name__
|
impl_name = impl.__class__.__name__
|
||||||
logger.info("Shutting down %s", impl_name)
|
logger.info("Shutting down %s", impl_name)
|
||||||
|
@ -157,36 +132,13 @@ def handle_signal(app, signum, _) -> None:
|
||||||
except (Exception, asyncio.CancelledError) as e:
|
except (Exception, asyncio.CancelledError) as e:
|
||||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
try:
|
|
||||||
# Gather all running tasks
|
|
||||||
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
|
|
||||||
|
|
||||||
# Cancel all tasks
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
# Wait for all tasks to finish
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.exception("Timeout while waiting for tasks to finish")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
loop.stop()
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
loop.create_task(shutdown())
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info("Starting up")
|
logger.info("Starting up")
|
||||||
yield
|
yield
|
||||||
logger.info("Shutting down")
|
logger.info("Shutting down")
|
||||||
for impl in app.__llama_stack_impls__.values():
|
await shutdown(app)
|
||||||
await impl.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||||
|
@ -436,8 +388,6 @@ def main():
|
||||||
|
|
||||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||||
app.exception_handler(Exception)(global_exception_handler)
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
|
||||||
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
|
|
||||||
|
|
||||||
app.__llama_stack_impls__ = impls
|
app.__llama_stack_impls__ = impls
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue