diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 085ddaafa..2f4578a86 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -157,10 +157,14 @@ class _NaiveSchedulerBackend(_SchedulerBackend): asyncio.set_event_loop(self._loop) self._loop.run_forever() - # When stopping the loop, give tasks a chance to finish + # TODO: When stopping the loop, give tasks a chance to finish # TODO: should we explicitly inform jobs of pending stoppage? + + # cancel all tasks for task in asyncio.all_tasks(self._loop): - self._loop.run_until_complete(task) + if not task.done(): + task.cancel() + self._loop.close() async def shutdown(self) -> None: diff --git a/tests/unit/providers/utils/test_scheduler.py b/tests/unit/providers/utils/test_scheduler.py index 76f0da8ce..25b4935de 100644 --- a/tests/unit/providers/utils/test_scheduler.py +++ b/tests/unit/providers/utils/test_scheduler.py @@ -17,6 +17,15 @@ async def test_scheduler_unknown_backend(): Scheduler(backend="unknown") +async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None: + for _ in range(10): + job = sched.get_job(job_id) + if job.completed_at is not None: + return + await asyncio.sleep(0.1) + raise TimeoutError(f"Job {job_id} did not complete in time.") + + @pytest.mark.asyncio async def test_scheduler_naive(): sched = Scheduler() @@ -52,6 +61,9 @@ async def test_scheduler_naive(): assert sched.get_jobs("unknown") == [] assert sched.get_jobs(job_type) == [sched.get_job(job_id)] + # give the job handler a chance to run + await wait_for_job_completed(sched, job_id) + # now shut the scheduler down and make sure the job ran await sched.shutdown() @@ -92,10 +104,7 @@ async def test_scheduler_naive_handler_raises(): # confirm the exception made the job transition to failed state, even # though it was set to `running` before the error - for _ in range(10): - if job.status == JobStatus.failed: - break - await asyncio.sleep(0.1) + await wait_for_job_completed(sched, job_id) assert job.status == JobStatus.failed # confirm that the raised error got registered in log @@ -111,6 +120,7 @@ async def test_scheduler_naive_handler_raises(): job_id = "test_job_id2" sched.schedule(job_type, job_id, successful_job_handler) + await wait_for_job_completed(sched, job_id) await sched.shutdown()