forked from phoenix-oss/llama-stack-mirror
feat: Implement async job execution for torchtune training (#1437)
# What does this PR do? Now a separate thread is started to execute training jobs. Training requests now return a job ID before the job completes. (This fixes API timeouts for any jobs that take longer than a minute.) Note: the scheduler code is meant to be spun out in the future into a common provider service that can be reused for different APIs and providers. It is also expected to back the /jobs API proposed here: https://github.com/meta-llama/llama-stack/discussions/1238 Hence its somewhat generalized form, which is expected to simplify its adoption elsewhere in the future. Note: this patch doesn't attempt to implement missing APIs (e.g. cancel or job removal). This work will belong to follow-up PRs. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] Added unit tests for the scheduler module. For the API coverage, I did manual testing and was able to run a training cycle on GPU. The initial call returned the job ID before the training completed, as (now) expected. Artifacts are returned as expected. ``` JobArtifactsResponse(checkpoints=[{'identifier': 'meta-llama/Llama-3.2-3B-Instruct-sft-0', 'created_at': '2025-03-07T22:45:19.892714', 'epoch': 0, 'post_training_job_id': 'test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50', 'path': '/home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0', 'training_metrics': None}], job_uuid='test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50') ``` The integration test is currently disabled for the provider. I will look into how it can be enabled in a different PR / issue context. [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
7641a5cd0b
commit
3ed4316ed5
3 changed files with 472 additions and 39 deletions
120
tests/unit/providers/utils/test_scheduler.py
Normal file
120
tests/unit/providers/utils/test_scheduler.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
    """Check that an unrecognized backend name is rejected with ValueError."""
    unsupported_backend = "unknown"
    with pytest.raises(ValueError):
        Scheduler(backend=unsupported_backend)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scheduler_naive():
    """Run one job through a default-constructed Scheduler and check its record.

    The test covers three phases:

    1. a fresh scheduler knows no jobs (lookup raises ValueError, listing is
       empty);
    2. after ``schedule()`` the job is visible via ``get_job`` and
       ``get_jobs`` (both unfiltered and filtered by job type);
    3. after ``shutdown()`` the handler has run, and the job carries the
       reported status, timestamps, artifacts and log messages.
    """
    sched = Scheduler()

    # make sure the scheduler starts empty
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_jobs() == []

    called = False

    # schedule a job that will exercise the handlers
    async def job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        # exercise the handlers
        on_log("test log1")
        on_log("test log2")
        on_artifact({"type": "type1", "path": "path1"})
        on_artifact({"type": "type2", "path": "path2"})
        on_status(JobStatus.completed)

    job_id = "test_job_id"
    job_type = "test_job_type"
    sched.schedule(job_type, job_id, job_handler)

    # make sure the job was properly registered
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_job(job_id) is not None
    assert sched.get_jobs() == [sched.get_job(job_id)]

    assert sched.get_jobs("unknown") == []
    assert sched.get_jobs(job_type) == [sched.get_job(job_id)]

    # now shut the scheduler down and make sure the job ran
    await sched.shutdown()

    assert called

    job = sched.get_job(job_id)
    assert job is not None

    assert job.status == JobStatus.completed

    assert job.scheduled_at is not None
    assert job.started_at is not None
    assert job.completed_at is not None
    # Non-strict ordering: two consecutive clock reads may return the same
    # value on a coarse-resolution clock, and that must not fail the test.
    assert job.scheduled_at <= job.started_at <= job.completed_at

    assert job.artifacts == [
        {"type": "type1", "path": "path1"},
        {"type": "type2", "path": "path2"},
    ]
    assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
    # Each log entry is indexed as (timestamp, message); same non-strict
    # ordering rationale as for the job timestamps above.
    assert job.logs[0][0] <= job.logs[1][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
    """Check that a raising handler fails its job without breaking the scheduler.

    The first job reports ``running`` and then raises; the test expects the
    job to end up ``failed`` with the exception text as its first log
    message. A second job scheduled afterwards must still run to completion.
    """
    scheduler = Scheduler()
    shared_job_type = "test_job_type"

    async def raising_handler(on_log, on_status, on_artifact):
        # Report `running` first so the failure has to override that status.
        on_status(JobStatus.running)
        raise ValueError("test error")

    failed_job_id = "test_job_id1"
    scheduler.schedule(shared_job_type, failed_job_id, raising_handler)

    failed_job = scheduler.get_job(failed_job_id)
    assert failed_job is not None

    # Poll for the transition to `failed` (at most 10 checks, 0.1s apart),
    # since the status is not asserted right after schedule() returns.
    polls_left = 10
    while polls_left and failed_job.status != JobStatus.failed:
        await asyncio.sleep(0.1)
        polls_left -= 1
    assert failed_job.status == JobStatus.failed

    # The raised error's message must be recorded as the first log entry.
    assert failed_job.logs[0][1] == "test error"

    # A failed job must not prevent scheduling another one.
    handler_ran = False

    async def passing_handler(on_log, on_status, on_artifact):
        nonlocal handler_ran
        handler_ran = True
        on_status(JobStatus.completed)

    passing_job_id = "test_job_id2"
    scheduler.schedule(shared_job_type, passing_job_id, passing_handler)

    # The assertions below rely on shutdown() returning only after the
    # second job has run.
    await scheduler.shutdown()

    assert handler_ran
    passing_job = scheduler.get_job(passing_job_id)
    assert passing_job is not None
    assert passing_job.status == JobStatus.completed
|
Loading…
Add table
Add a link
Reference in a new issue