diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index b7ee4a219..fddb62570 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -82,3 +82,6 @@ class DistributionInspectImpl(Inspect): async def version(self) -> VersionInfo: return VersionInfo(version=version("llama-stack")) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 68fafd8ee..009775ca5 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -537,3 +537,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): for tool in tools: await self.unregister_object(tool) await self.unregister_object(tool_group) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index d2c32de11..bb735268b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -7,6 +7,7 @@ import argparse import asyncio import functools +import logging import inspect import json import os @@ -52,6 +53,9 @@ from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") +logger = logging.getLogger(__name__) + def warn_with_traceback(message, category, filename, lineno, file=None, line=None): log = file if hasattr(file, "write") else sys.stderr @@ -112,21 +116,69 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) -def handle_sigint(app, *args, **kwargs): - print("SIGINT or CTRL-C detected. Exiting gracefully...") +def handle_signal(app, signum, _) -> None: + """ + Handle incoming signals and initiate a graceful shutdown of the application. - async def run_shutdown(): - for impl in app.__llama_stack_impls__.values(): - print(f"Shutting down {impl}") - await impl.shutdown() + This function is intended to be used as a signal handler for various signals + (e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message + indicating the received signal and initiate a shutdown process. - asyncio.run(run_shutdown()) + Args: + app: The application instance containing implementations to be shut down. + signum (int): The signal number received. + frame: The current stack frame (not used in this function). - loop = asyncio.get_event_loop() - for task in asyncio.all_tasks(loop): - task.cancel() + The shutdown process involves: + - Shutting down all implementations registered in the application. + - Gathering all running asyncio tasks. + - Cancelling all gathered tasks. + - Waiting for all tasks to finish. + - Stopping the event loop. - loop.stop() + Note: + This function schedules the shutdown process as an asyncio task and does + not block the current execution. + """ + signame = signal.Signals(signum).name + print(f"Received signal {signame} ({signum}). Exiting gracefully...") + + async def shutdown(): + try: + # Gracefully shut down implementations + for impl in app.__llama_stack_impls__.values(): + impl_name = impl.__class__.__name__ + logger.info("Shutting down %s", impl_name) + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning("No shutdown method for %s", impl_name) + except asyncio.TimeoutError: + logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) + except Exception as e: + logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + + # Gather all running tasks + loop = asyncio.get_running_loop() + tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()] + + # Cancel all tasks + for task in tasks: + task.cancel() + + # Wait for all tasks to finish + try: + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) + except asyncio.TimeoutError: + logger.exception("Timeout while waiting for tasks to finish") + except asyncio.CancelledError: + pass + finally: + loop.stop() + + loop = asyncio.get_running_loop() + loop.create_task(shutdown()) @asynccontextmanager @@ -386,7 +438,8 @@ def main(): print("") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) - signal.signal(signal.SIGINT, functools.partial(handle_sigint, app)) + signal.signal(signal.SIGINT, functools.partial(handle_signal, app)) + signal.signal(signal.SIGTERM, functools.partial(handle_signal, app)) app.__llama_stack_impls__ = impls diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index fe4ccd1a3..e3c18d112 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -212,3 +212,6 @@ class MetaReferenceAgentsImpl(Agents): async def delete_agent(self, agent_id: str) -> None: await self.persistence_store.delete(f"agent:{agent_id}") + + async def shutdown(self) -> None: + pass diff --git a/pyproject.toml b/pyproject.toml index 5e9cb75e2..2f40ceac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ dev = [ "types-requests", "types-setuptools", "pre-commit", + "uvicorn", + "fastapi", ] docs = [ "sphinx-autobuild", diff --git a/uv.lock b/uv.lock index 087396eea..97ae52124 100644 --- a/uv.lock +++ b/uv.lock @@ -431,6 +431,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] +[[package]] +name = "fastapi" +version = "0.115.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a2/b2/5a5dc4affdb6661dea100324e19a7721d5dc524b464fe8e366c093fd7d87/fastapi-0.115.8.tar.gz", hash = "sha256:0ce9111231720190473e222cdf0f07f7206ad7e53ea02beb1d2dc36e2f0741e9", size = 295403 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/7d/2d6ce181d7a5f51dedb8c06206cbf0ec026a99bf145edd309f9e17c3282f/fastapi-0.115.8-py3-none-any.whl", hash = "sha256:753a96dd7e036b34eeef8babdfcfe3f28ff79648f86551eb36bfc1b0bf4a8cbf", size = 94814 }, +] + [[package]] name = "fastjsonschema" version = "2.21.1" @@ -724,6 +738,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "black" }, + { name = "fastapi" }, { name = "nbval" }, { name = "pre-commit" }, { name = "pytest" }, @@ -731,6 +746,7 @@ dev = [ { name = "ruff" }, { name = "types-requests" }, { name = "types-setuptools" }, + { name = "uvicorn" }, ] docs = [ { name = "myst-parser" }, @@ -748,6 +764,7 @@ docs = [ requires-dist = [ { name = "black", marker = "extra == 'dev'" }, { name = "blobfile" }, + { name = "fastapi", marker = "extra == 'dev'" }, { name = "fire" }, { name = "httpx" }, { name = "huggingface-hub" }, @@ -776,6 +793,7 @@ requires-dist = [ { name = "termcolor" }, { name = "types-requests", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, + { name = "uvicorn", marker = "extra == 'dev'" }, ] [[package]]