llama-stack-mirror/tests/unit/providers/utils/test_scheduler.py
Ihar Hrachyshka a2f054607d
fix: cancel scheduler tasks on shutdown (#2130)
# What does this PR do?

Scheduler: cancel tasks on shutdown.

Otherwise, tasks that are still running will not exit until they actually
complete, which means the process cannot be shut down properly (only with
SIGKILL).

Ideally, we would let tasks know that they are about to be shut down and give
them some time to finish; but in the absence of such a mechanism, it's better
to cancel them than to linger forever.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Start a long-running task (e.g. torchtune or external kfp-provider
training).
Press Ctrl-C in the TTY running the process. Confirm it exits in a reasonable time.

```
^CINFO:     Shutting down
INFO:     Waiting for application shutdown.
13:32:26.187 - INFO - Shutting down
13:32:26.187 - INFO - Shutting down DatasetsRoutingTable
13:32:26.187 - INFO - Shutting down DatasetIORouter
13:32:26.187 - INFO - Shutting down TorchtuneKFPPostTrainingImpl
    Traceback (most recent call last):
      File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 118, in run
        return self._loop.run_until_complete(task)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
        return future.result()
               ^^^^^^^^^^^^^^^
    asyncio.exceptions.CancelledError

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "<frozen runpy>", line 198, in _run_module_as_main
      File "<frozen runpy>", line 88, in _run_code
      File "/Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/executor_main.py", line 109, in <module>
        executor_main()
      File "/Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/executor_main.py", line 101, in executor_main
        output_file = executor.execute()
                      ^^^^^^^^^^^^^^^^^^
      File "/Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/executor.py", line 361, in execute
        result = self.func(**func_kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^
      File "/var/folders/45/1q1rx6cn7jbcn2ty852w0g_r0000gn/T/tmp.RKpPrvTWDD/ephemeral_component.py", line 118, in component
        asyncio.run(recipe.setup())
      File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 194, in run
        return runner.run(main)
               ^^^^^^^^^^^^^^^^
      File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 123, in run
        raise KeyboardInterrupt()
    KeyboardInterrupt


13:32:31.219 - ERROR - Task 'component' finished with status FAILURE
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
INFO     2025-05-09 13:32:31,221 llama_stack.providers.utils.scheduler:221 scheduler: Job
         test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa: Pipeline [1m[95m'test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa'[1m[0m
         finished with status [1m[91mFAILURE[1m[0m. Inner task failed: [1m[96m'component'[1m[0m.
ERROR    2025-05-09 13:32:31,223 llama_stack_provider_kfp_trainer.scheduler:54 scheduler: Job
         test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa failed.
         ╭───────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────╮
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/src/llama_stack_provider_kfp_trainer/scheduler.py:45   │
         │ in do                                                                                                       │
         │                                                                                                             │
         │    42 │   │   │                                                                                             │
         │    43 │   │   │   job.status = JobStatus.running                                                            │
         │    44 │   │   │   try:                                                                                      │
         │ ❱  45 │   │   │   │   artifacts = self._to_artifacts(job.handler().output)                                  │
         │    46 │   │   │   │   for artifact in artifacts:                                                            │
         │    47 │   │   │   │   │   on_artifact_collected_cb(artifact)                                                │
         │    48                                                                                                       │
         │                                                                                                             │
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/base_compon │
         │ ent.py:101 in __call__                                                                                      │
         │                                                                                                             │
         │    98 │   │   │   │   f'{self.name}() missing {len(missing_arguments)} required '                           │
         │    99 │   │   │   │   f'{argument_or_arguments}: {arguments}.')                                             │
         │   100 │   │                                                                                                 │
         │ ❱ 101 │   │   return pipeline_task.PipelineTask(                                                            │
         │   102 │   │   │   component_spec=self.component_spec,                                                       │
         │   103 │   │   │   args=task_inputs,                                                                         │
         │   104 │   │   │   execute_locally=pipeline_context.Pipeline.get_default_pipeline() is                       │
         │                                                                                                             │
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/pipeline_ta │
         │ sk.py:187 in __init__                                                                                       │
         │                                                                                                             │
         │   184 │   │   ])                                                                                            │
         │   185 │   │                                                                                                 │
         │   186 │   │   if execute_locally:                                                                           │
         │ ❱ 187 │   │   │   self._execute_locally(args=args)                                                          │
         │   188 │                                                                                                     │
         │   189 │   def _execute_locally(self, args: Dict[str, Any]) -> None:                                         │
         │   190 │   │   """Execute the pipeline task locally.                                                         │
         │                                                                                                             │
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/pipeline_ta │
         │ sk.py:197 in _execute_locally                                                                               │
         │                                                                                                             │
         │   194 │   │   from kfp.local import task_dispatcher                                                         │
         │   195 │   │                                                                                                 │
         │   196 │   │   if self.pipeline_spec is not None:                                                            │
         │ ❱ 197 │   │   │   self._outputs = pipeline_orchestrator.run_local_pipeline(                                 │
         │   198 │   │   │   │   pipeline_spec=self.pipeline_spec,                                                     │
         │   199 │   │   │   │   arguments=args,                                                                       │
         │   200 │   │   │   )                                                                                         │
         │                                                                                                             │
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/local/pipeline_ │
         │ orchestrator.py:43 in run_local_pipeline                                                                    │
         │                                                                                                             │
         │    40 │                                                                                                     │
         │    41 │   # validate and access all global state in this function, not downstream                           │
         │    42 │   config.LocalExecutionConfig.validate()                                                            │
         │ ❱  43 │   return _run_local_pipeline_implementation(                                                        │
         │    44 │   │   pipeline_spec=pipeline_spec,                                                                  │
         │    45 │   │   arguments=arguments,                                                                          │
         │    46 │   │   raise_on_error=config.LocalExecutionConfig.instance.raise_on_error,                           │
         │                                                                                                             │
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/local/pipeline_ │
         │ orchestrator.py:108 in _run_local_pipeline_implementation                                                   │
         │                                                                                                             │
         │   105 │   │   │   )                                                                                         │
         │   106 │   │   return outputs                                                                                │
         │   107 │   elif dag_status == status.Status.FAILURE:                                                         │
         │ ❱ 108 │   │   log_and_maybe_raise_for_failure(                                                              │
         │   109 │   │   │   pipeline_name=pipeline_name,                                                              │
         │   110 │   │   │   fail_stack=fail_stack,                                                                    │
         │   111 │   │   │   raise_on_error=raise_on_error,                                                            │
         │                                                                                                             │
         │ /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/local/pipeline_ │
         │ orchestrator.py:137 in log_and_maybe_raise_for_failure                                                      │
         │                                                                                                             │
         │   134 │   │   logging_utils.format_task_name(task_name) for task_name in fail_stack)                        │
         │   135 │   msg = f'Pipeline {pipeline_name_with_color} finished with status                                  │
         │       {status_with_color}. Inner task failed: {task_chain_with_color}.'                                     │
         │   136 │   if raise_on_error:                                                                                │
         │ ❱ 137 │   │   raise RuntimeError(msg)                                                                       │
         │   138 │   with logging_utils.local_logger_context():                                                        │
         │   139 │   │   logging.error(msg)                                                                            │
         │   140                                                                                                       │
         ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         RuntimeError: Pipeline [1m[95m'test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa'[1m[0m finished with status
         [1m[91mFAILURE[1m[0m. Inner task failed: [1m[96m'component'[1m[0m.
INFO     2025-05-09 13:32:31,266 llama_stack.distribution.server.server:136 server: Shutting down
         DistributionInspectImpl
INFO     2025-05-09 13:32:31,266 llama_stack.distribution.server.server:136 server: Shutting down ProviderImpl
INFO:     Application shutdown complete.
INFO:     Finished server process [26648]
```

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
2025-06-19 17:01:33 +02:00

130 lines
3.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
    """Constructing a scheduler with an unrecognized backend raises ValueError."""
    bogus_backend = "unknown"
    with pytest.raises(ValueError):
        Scheduler(backend=bogus_backend)
async def wait_for_job_completed(
    sched: "Scheduler",
    job_id: str,
    attempts: int = 10,
    interval: float = 0.1,
) -> None:
    """Poll the scheduler until the given job has a completion timestamp.

    The job is looked up via ``sched.get_job(job_id)`` and considered done as
    soon as its ``completed_at`` attribute is not ``None``. Between lookups the
    coroutine sleeps for ``interval`` seconds, so the total wait is at most
    ``attempts * interval`` seconds (1 second with the defaults).

    Args:
        sched: Scheduler to poll; only its ``get_job`` method is used.
        job_id: Identifier of the job to wait for.
        attempts: Number of sleep intervals to wait through before giving up.
        interval: Seconds to sleep between two lookups.

    Raises:
        TimeoutError: If the job still has no ``completed_at`` after the last
            lookup.
    """
    for _ in range(attempts):
        if sched.get_job(job_id).completed_at is not None:
            return
        await asyncio.sleep(interval)

    # Look one more time after the final sleep; otherwise a job that finishes
    # during the last interval would still be reported as timed out.
    if sched.get_job(job_id).completed_at is not None:
        return

    raise TimeoutError(f"Job {job_id} did not complete in time.")
@pytest.mark.asyncio
async def test_scheduler_naive():
    """Run a single job through a default ``Scheduler`` and check its record.

    Covers job registration and lookup, the log / artifact / status callbacks
    handed to the job handler, and the timestamps exposed on the finished job.
    """
    sched = Scheduler()
    called = False

    # schedule a job that will exercise the handlers
    async def job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        # exercise the handlers
        on_log("test log1")
        on_log("test log2")
        on_artifact({"type": "type1", "path": "path1"})
        on_artifact({"type": "type2", "path": "path2"})
        on_status(JobStatus.completed)

    job_id = "test_job_id"
    job_type = "test_job_type"

    # Everything that can fail while the scheduler is live sits inside the
    # try block so the scheduler is always shut down, even on failure.
    try:
        # make sure the scheduler starts empty
        with pytest.raises(ValueError):
            sched.get_job("unknown")
        assert sched.get_jobs() == []

        sched.schedule(job_type, job_id, job_handler)

        # make sure the job was properly registered
        with pytest.raises(ValueError):
            sched.get_job("unknown")
        assert sched.get_job(job_id) is not None
        assert sched.get_jobs() == [sched.get_job(job_id)]
        assert sched.get_jobs("unknown") == []
        assert sched.get_jobs(job_type) == [sched.get_job(job_id)]

        # give the job handler a chance to run
        await wait_for_job_completed(sched, job_id)
    finally:
        # now shut the scheduler down
        await sched.shutdown()

    # ... and make sure the job ran
    assert called

    job = sched.get_job(job_id)
    assert job is not None
    assert job.status == JobStatus.completed

    assert job.scheduled_at is not None
    assert job.started_at is not None
    assert job.completed_at is not None
    # Only require non-decreasing order: two clock readings taken in quick
    # succession may be equal (presumably these are wall-clock timestamps --
    # verify against the scheduler implementation).
    assert job.scheduled_at <= job.started_at <= job.completed_at

    assert job.artifacts == [
        {"type": "type1", "path": "path1"},
        {"type": "type2", "path": "path2"},
    ]

    assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
    # The two log calls are back to back, so their timestamps may coincide.
    assert job.logs[0][0] <= job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
    """A raising handler fails its job but leaves the scheduler usable.

    First schedules a handler that reports ``running`` and then raises, and
    checks the job ends up ``failed`` with the error text in its log. Then
    schedules a second, well-behaved handler on the same scheduler and checks
    that it runs to completion.
    """
    sched = Scheduler()
    job_type = "test_job_type"
    called = False

    async def failing_job_handler(on_log, on_status, on_artifact):
        on_status(JobStatus.running)
        raise ValueError("test error")

    async def successful_job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        on_status(JobStatus.completed)

    failed_job_id = "test_job_id1"
    ok_job_id = "test_job_id2"

    # Keep the scheduler's lifetime inside try/finally so a failing assertion
    # or a timed-out wait does not leave it running.
    try:
        sched.schedule(job_type, failed_job_id, failing_job_handler)
        failed_job = sched.get_job(failed_job_id)
        assert failed_job is not None

        # confirm the exception made the job transition to failed state, even
        # though it was set to `running` before the error
        await wait_for_job_completed(sched, failed_job_id)
        assert failed_job.status == JobStatus.failed

        # confirm that the raised error got registered in log
        assert failed_job.logs[0][1] == "test error"

        # even after failed job, we can schedule another one
        sched.schedule(job_type, ok_job_id, successful_job_handler)
        await wait_for_job_completed(sched, ok_job_id)
    finally:
        await sched.shutdown()

    assert called
    ok_job = sched.get_job(ok_job_id)
    assert ok_job is not None
    assert ok_job.status == JobStatus.completed