llama-stack-mirror/tests/unit/providers/utils/test_scheduler.py
Ihar Hrachyshka 2433ef218d feat: implement async job scheduler for torchtune
A separate thread is now started to execute training jobs. Training
requests now return a job ID before the job completes, which fixes API
timeouts for jobs that take longer than a minute.

Note: the scheduler code is meant to be spun out in the future into a
common provider service that can be reused for different APIs and
providers. It is also expected to back the /jobs API proposed here:

https://github.com/meta-llama/llama-stack/discussions/1238

Hence its somewhat generalized form, which is expected to simplify its
adoption elsewhere in the future.

Note: this patch doesn't attempt to implement missing APIs (e.g. job
cancellation or removal). That work belongs to follow-up PRs.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
2025-03-28 12:11:59 -04:00

120 lines
3.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
    """Constructing a Scheduler with an unrecognized backend name raises ValueError."""
    bogus_backend = "unknown"
    with pytest.raises(ValueError):
        Scheduler(backend=bogus_backend)
@pytest.mark.asyncio
async def test_scheduler_naive():
    """Exercise the default scheduler backend end to end with a successful job.

    Checks that:
    - a fresh scheduler reports no jobs and rejects unknown job ids;
    - a scheduled job is registered and discoverable by id and by type;
    - after shutdown, the job has run, reached ``JobStatus.completed``, and
      recorded its artifacts and log messages in emission order.
    """
    sched = Scheduler()

    # make sure the scheduler starts empty
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_jobs() == []

    called = False

    # schedule a job that will exercise the handlers
    async def job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        # exercise the handlers
        on_log("test log1")
        on_log("test log2")
        on_artifact({"type": "type1", "path": "path1"})
        on_artifact({"type": "type2", "path": "path2"})
        on_status(JobStatus.completed)

    job_id = "test_job_id"
    job_type = "test_job_type"
    sched.schedule(job_type, job_id, job_handler)

    try:
        # make sure the job was properly registered
        with pytest.raises(ValueError):
            sched.get_job("unknown")
        assert sched.get_job(job_id) is not None
        assert sched.get_jobs() == [sched.get_job(job_id)]
        assert sched.get_jobs("unknown") == []
        assert sched.get_jobs(job_type) == [sched.get_job(job_id)]
    finally:
        # Always shut down, even if a registration check above fails, so the
        # scheduler's worker does not outlive this test. The checks below
        # assume shutdown lets the scheduled job run to completion.
        await sched.shutdown()

    # the job must have run by now
    assert called
    job = sched.get_job(job_id)
    assert job is not None
    assert job.status == JobStatus.completed
    assert job.scheduled_at is not None
    assert job.started_at is not None
    assert job.completed_at is not None
    # Non-strict comparisons: consecutive timestamps can compare equal on
    # clocks with coarse resolution, which would make strict `<` flaky.
    assert job.scheduled_at <= job.started_at <= job.completed_at
    assert job.artifacts == [
        {"type": "type1", "path": "path1"},
        {"type": "type2", "path": "path2"},
    ]
    # logs are (timestamp, message) pairs, kept in emission order
    assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
    assert job.logs[0][0] <= job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
    """A job whose handler raises ends up failed without breaking the scheduler.

    The failing handler first reports ``JobStatus.running`` and then raises.
    The job is expected to transition to ``JobStatus.failed``, with the error
    message recorded as its first log entry. A second, successful job is then
    scheduled on the same scheduler to show that it keeps accepting work.
    """
    sched = Scheduler()
    job_type = "test_job_type"

    async def failing_job_handler(on_log, on_status, on_artifact):
        on_status(JobStatus.running)
        raise ValueError("test error")

    called = False

    async def successful_job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        on_status(JobStatus.completed)

    failing_job_id = "test_job_id1"
    successful_job_id = "test_job_id2"

    try:
        sched.schedule(job_type, failing_job_id, failing_job_handler)
        failing_job = sched.get_job(failing_job_id)
        assert failing_job is not None

        # confirm the exception made the job transition to failed state, even
        # though it was set to `running` before the error. The job executes
        # asynchronously, so poll with an early exit; the generous budget
        # (~5s) only matters on slow machines.
        for _ in range(50):
            if failing_job.status == JobStatus.failed:
                break
            await asyncio.sleep(0.1)
        assert failing_job.status == JobStatus.failed
        # confirm that the raised error got registered in log
        assert failing_job.logs[0][1] == "test error"

        # even after failed job, we can schedule another one
        sched.schedule(job_type, successful_job_id, successful_job_handler)
    finally:
        # Always shut down, even if an assertion above fails, so the
        # scheduler's worker does not outlive this test.
        await sched.shutdown()

    assert called
    job = sched.get_job(successful_job_id)
    assert job is not None
    assert job.status == JobStatus.completed