fix(registry): ensure clean shutdown (#2901)
Some checks failed
Coverage Badge / unit-tests (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / discover-tests (push) Successful in 4s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 5s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 6s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 4s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 5s
Python Package Build Test / build (3.13) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 9s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Integration Tests / test-matrix (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 13s
Test Llama Stack Build / build (push) Failing after 4s
Pre-commit / pre-commit (push) Successful in 57s

Avoid the error message:

```
INFO     2025-07-24 21:51:54,530 __main__:598 server: Received interrupt signal, shutting down gracefully...                                          
ERROR    2025-07-24 21:51:54,692 asyncio:1826 uncategorized: Task was destroyed but it is pending!                                                    
         task: <Task pending name='Task-15' coro=<refresh_registry() running at                                                                       
         /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/stack.py:356> wait_for=<Future pending cb=[Task.task_wakeup()]> cb=>  
```
This commit is contained in:
Ashwin Bharambe 2025-07-25 06:44:31 -07:00 committed by GitHub
parent de6919ecdd
commit ed07a58b50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 14 deletions

View file

@ -56,6 +56,7 @@ from llama_stack.distribution.stack import (
cast_image_name_to_string, cast_image_name_to_string,
construct_stack, construct_stack,
replace_env_vars, replace_env_vars,
shutdown_stack,
validate_env_pair, validate_env_pair,
) )
from llama_stack.distribution.utils.config import redact_sensitive_fields from llama_stack.distribution.utils.config import redact_sensitive_fields
@ -151,18 +152,7 @@ async def shutdown(app):
Handled by the lifespan context manager. The shutdown process involves Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application. shutting down all implementations registered in the application.
""" """
for impl in app.__llama_stack_impls__.values(): await shutdown_stack(app.__llama_stack_impls__)
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@asynccontextmanager @asynccontextmanager

View file

@ -93,6 +93,7 @@ RESOURCES = [
REGISTRY_REFRESH_INTERVAL_SECONDS = 300 REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
@ -330,7 +331,8 @@ async def construct_stack(
await register_resources(run_config, impls) await register_resources(run_config, impls)
task = asyncio.create_task(refresh_registry(impls)) global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
def cb(task): def cb(task):
import traceback import traceback
@ -343,10 +345,29 @@ async def construct_stack(
else: else:
logger.debug("Model refresh task completed") logger.debug("Model refresh task completed")
task.add_done_callback(cb) REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry(impls: dict[Api, Any]): async def refresh_registry(impls: dict[Api, Any]):
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True: while True: