fix(registry): ensure clean shutdown

This commit is contained in:
Ashwin Bharambe 2025-07-24 18:53:49 -07:00
parent 3216765c26
commit c480d49917
2 changed files with 25 additions and 14 deletions

View file

@ -56,6 +56,7 @@ from llama_stack.distribution.stack import (
cast_image_name_to_string, cast_image_name_to_string,
construct_stack, construct_stack,
replace_env_vars, replace_env_vars,
shutdown_stack,
validate_env_pair, validate_env_pair,
) )
from llama_stack.distribution.utils.config import redact_sensitive_fields from llama_stack.distribution.utils.config import redact_sensitive_fields
@ -151,18 +152,7 @@ async def shutdown(app):
Handled by the lifespan context manager. The shutdown process involves Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application. shutting down all implementations registered in the application.
""" """
for impl in app.__llama_stack_impls__.values(): await shutdown_stack(app.__llama_stack_impls__)
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@asynccontextmanager @asynccontextmanager

View file

@ -93,6 +93,7 @@ RESOURCES = [
REGISTRY_REFRESH_INTERVAL_SECONDS = 300 REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
@ -330,7 +331,8 @@ async def construct_stack(
await register_resources(run_config, impls) await register_resources(run_config, impls)
task = asyncio.create_task(refresh_registry(impls)) global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
def cb(task): def cb(task):
import traceback import traceback
@ -343,10 +345,29 @@ async def construct_stack(
else: else:
logger.debug("Model refresh task completed") logger.debug("Model refresh task completed")
task.add_done_callback(cb) REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls return impls
async def shutdown_stack(impls: dict[Api, Any]):
    """Stop the registry refresh task and shut down every implementation.

    The background task stored in ``REGISTRY_REFRESH_TASK`` is cancelled
    first and the global is reset to ``None``. Each value in ``impls`` is
    then shut down in turn on a best-effort basis: a failure or timeout in
    one implementation is logged and does not prevent the remaining ones
    from being shut down.

    Args:
        impls: Mapping of API to implementation object. Implementations
            exposing an awaitable ``shutdown()`` get up to 5 seconds each
            to complete; those without one are logged and skipped.
    """
    global REGISTRY_REFRESH_TASK

    # Cancel the refresh loop before touching the implementations: it is
    # built from these same impls, so letting it keep running while they are
    # being torn down (up to 5s each) risks it calling into closed resources.
    if REGISTRY_REFRESH_TASK is not None:
        REGISTRY_REFRESH_TASK.cancel()
        # Drop the reference so a later shutdown does not act on a stale task.
        REGISTRY_REFRESH_TASK = None

    for impl in impls.values():
        impl_name = impl.__class__.__name__
        logger.info(f"Shutting down {impl_name}")
        try:
            if hasattr(impl, "shutdown"):
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                logger.warning(f"No shutdown method for {impl_name}")
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is the builtin TimeoutError on 3.11+ and a
            # distinct class before that; naming it explicitly matches both.
            logger.exception(f"Shutdown timeout for {impl_name}")
        except (Exception, asyncio.CancelledError) as e:
            # Deliberately swallow CancelledError too: shutdown must continue
            # with the remaining implementations even if one is cancelled.
            logger.exception(f"Failed to shutdown {impl_name}: {e}")
async def refresh_registry(impls: dict[Api, Any]): async def refresh_registry(impls: dict[Api, Any]):
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True: while True: