fix(registry): ensure clean shutdown

This commit is contained in:
Ashwin Bharambe 2025-07-24 18:53:49 -07:00
parent 3216765c26
commit c480d49917
2 changed files with 25 additions and 14 deletions

View file

@ -56,6 +56,7 @@ from llama_stack.distribution.stack import (
cast_image_name_to_string,
construct_stack,
replace_env_vars,
shutdown_stack,
validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
@ -151,18 +152,7 @@ async def shutdown(app):
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
"""
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
await shutdown_stack(app.__llama_stack_impls__)
@asynccontextmanager

View file

@ -93,6 +93,7 @@ RESOURCES = [
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
@ -330,7 +331,8 @@ async def construct_stack(
await register_resources(run_config, impls)
task = asyncio.create_task(refresh_registry(impls))
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
def cb(task):
import traceback
@ -343,10 +345,29 @@ async def construct_stack(
else:
logger.debug("Model refresh task completed")
task.add_done_callback(cb)
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry(impls: dict[Api, Any]):
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True: