diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f819d446f..ea8723365 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -6,11 +6,9 @@ import argparse import asyncio -import functools import inspect import json import os -import signal import sys import traceback import warnings @@ -118,69 +116,24 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) -def handle_signal(app, signum, _) -> None: +async def shutdown(app): + """Initiate a graceful shutdown of the application. + + Handled by the lifespan context manager. The shutdown process involves + shutting down all implementations registered in the application. """ - Handle incoming signals and initiate a graceful shutdown of the application. - - This function is intended to be used as a signal handler for various signals - (e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message - indicating the received signal and initiate a shutdown process. - - Args: - app: The application instance containing implementations to be shut down. - signum (int): The signal number received. - frame: The current stack frame (not used in this function). - - The shutdown process involves: - - Shutting down all implementations registered in the application. - - Gathering all running asyncio tasks. - - Cancelling all gathered tasks. - - Waiting for all tasks to finish. - - Stopping the event loop. - - Note: - This function schedules the shutdown process as an asyncio task and does - not block the current execution. - """ - signame = signal.Signals(signum).name - logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...") - - async def shutdown(): + for impl in app.__llama_stack_impls__.values(): + impl_name = impl.__class__.__name__ + logger.info("Shutting down %s", impl_name) try: - # Gracefully shut down implementations - for impl in app.__llama_stack_impls__.values(): - impl_name = impl.__class__.__name__ - logger.info("Shutting down %s", impl_name) - try: - if hasattr(impl, "shutdown"): - await asyncio.wait_for(impl.shutdown(), timeout=5) - else: - logger.warning("No shutdown method for %s", impl_name) - except asyncio.TimeoutError: - logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) - except Exception as e: - logger.exception("Failed to shutdown %s: %s", impl_name, {e}) - - # Gather all running tasks - loop = asyncio.get_running_loop() - tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()] - - # Cancel all tasks - for task in tasks: - task.cancel() - - # Wait for all tasks to finish - try: - await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) - except asyncio.TimeoutError: - logger.exception("Timeout while waiting for tasks to finish") - except asyncio.CancelledError: - pass - finally: - loop.stop() - - loop = asyncio.get_running_loop() - loop.create_task(shutdown()) + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning("No shutdown method for %s", impl_name) + except asyncio.TimeoutError: + logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) + except (Exception, asyncio.CancelledError) as e: + logger.exception("Failed to shutdown %s: %s", impl_name, {e}) @asynccontextmanager @@ -188,8 +141,7 @@ async def lifespan(app: FastAPI): logger.info("Starting up") yield logger.info("Shutting down") - for impl in app.__llama_stack_impls__.values(): - await impl.shutdown() + await shutdown(app) def is_streaming_request(func_name: str, request: Request, **kwargs): @@ -266,7 +218,7 @@ class TracingMiddleware: self.app = app async def __call__(self, scope, receive, send): - path = scope["path"] + path = scope.get("path", "") await start_trace(path, {"__location__": "server"}) try: return await self.app(scope, receive, send) @@ -439,8 +391,6 @@ def main(): app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) - signal.signal(signal.SIGINT, functools.partial(handle_signal, app)) - signal.signal(signal.SIGTERM, functools.partial(handle_signal, app)) app.__llama_stack_impls__ = impls @@ -471,6 +421,7 @@ def main(): "app": app, "host": listen_host, "port": port, + "lifespan": "on", } if ssl_config: uvicorn_config.update(ssl_config)