From c480d4991775501d177478aa4923ce2bb695f3f0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Jul 2025 18:53:49 -0700 Subject: [PATCH] fix(registry): ensure clean shutdown --- llama_stack/distribution/server/server.py | 14 ++----------- llama_stack/distribution/stack.py | 25 +++++++++++++++++++++-- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 26ea5f90c..9259fc243 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -56,6 +56,7 @@ from llama_stack.distribution.stack import ( cast_image_name_to_string, construct_stack, replace_env_vars, + shutdown_stack, validate_env_pair, ) from llama_stack.distribution.utils.config import redact_sensitive_fields @@ -151,18 +152,7 @@ async def shutdown(app): Handled by the lifespan context manager. The shutdown process involves shutting down all implementations registered in the application. """ - for impl in app.__llama_stack_impls__.values(): - impl_name = impl.__class__.__name__ - logger.info("Shutting down %s", impl_name) - try: - if hasattr(impl, "shutdown"): - await asyncio.wait_for(impl.shutdown(), timeout=5) - else: - logger.warning("No shutdown method for %s", impl_name) - except TimeoutError: - logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) - except (Exception, asyncio.CancelledError) as e: - logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + await shutdown_stack(app.__llama_stack_impls__) @asynccontextmanager diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 57bc4cd5f..0dfd12828 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -93,6 +93,7 @@ RESOURCES = [ REGISTRY_REFRESH_INTERVAL_SECONDS = 300 +REGISTRY_REFRESH_TASK = None async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): @@ -330,7 +331,8 @@ async def construct_stack( await register_resources(run_config, impls) - task = asyncio.create_task(refresh_registry(impls)) + global REGISTRY_REFRESH_TASK + REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls)) def cb(task): import traceback @@ -343,10 +345,29 @@ async def construct_stack( else: logger.debug("Model refresh task completed") - task.add_done_callback(cb) + REGISTRY_REFRESH_TASK.add_done_callback(cb) return impls +async def shutdown_stack(impls: dict[Api, Any]): + for impl in impls.values(): + impl_name = impl.__class__.__name__ + logger.info(f"Shutting down {impl_name}") + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning(f"No shutdown method for {impl_name}") + except TimeoutError: + logger.exception(f"Shutdown timeout for {impl_name}") + except (Exception, asyncio.CancelledError) as e: + logger.exception(f"Failed to shutdown {impl_name}: {e}") + + global REGISTRY_REFRESH_TASK + if REGISTRY_REFRESH_TASK: + REGISTRY_REFRESH_TASK.cancel() + + async def refresh_registry(impls: dict[Api, Any]): routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] while True: