diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index c4ef79a69..b85c463ae 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -143,23 +143,23 @@ def handle_signal(app, signum, _) -> None: logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...") async def shutdown(): - try: - # Gracefully shut down implementations - for impl in app.__llama_stack_impls__.values(): - impl_name = impl.__class__.__name__ - logger.info("Shutting down %s", impl_name) - try: - if hasattr(impl, "shutdown"): - await asyncio.wait_for(impl.shutdown(), timeout=5) - else: - logger.warning("No shutdown method for %s", impl_name) - except asyncio.TimeoutError: - logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) - except Exception as e: - logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + # Gracefully shut down implementations + for impl in app.__llama_stack_impls__.values(): + impl_name = impl.__class__.__name__ + logger.info("Shutting down %s", impl_name) + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning("No shutdown method for %s", impl_name) + except asyncio.TimeoutError: + logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) + except (Exception, asyncio.CancelledError) as e: + logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + loop = asyncio.get_running_loop() + try: # Gather all running tasks - loop = asyncio.get_running_loop() tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()] # Cancel all tasks