feat: Implement async job execution for torchtune training (#1437)

# What does this PR do?

Now a separate thread is started to execute training jobs. Training
requests now return job ID before the job completes. (Which fixes API
timeouts for any jobs that take longer than a minute.)

Note: the scheduler code is meant to be spun out in the future into a
common provider service that can be reused for different APIs and
providers. It is also expected to back the /jobs API proposed here:

https://github.com/meta-llama/llama-stack/discussions/1238

Hence its somewhat generalized form which is expected to simplify its
adoption elsewhere in the future.

Note: this patch doesn't attempt to implement missing APIs (e.g. cancel
or job removal). This work will belong to follow-up PRs.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

Added unit tests for the scheduler module. For the API coverage, did
manual testing and was able to run a training cycle on GPU. The initial
call returned job ID before the training completed, as (now) expected.
Artifacts are returned as expected.

```
JobArtifactsResponse(checkpoints=[{'identifier': 'meta-llama/Llama-3.2-3B-Instruct-sft-0', 'created_at': '2025-03-07T22:45:19.892714', 'epoch': 0, 'post_training_job_id': 'test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50', 'path': '/home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0', 'training_metrics': None}], job_uuid='test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50')
```

The integration test is currently disabled for the provider. I will look
into how it can be enabled in a different PR / issue context.

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-04-14 11:59:11 -04:00 committed by GitHub
parent 7641a5cd0b
commit 3ed4316ed5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 472 additions and 39 deletions

View file

@ -3,13 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self) -> None:
await self._scheduler.shutdown()
async def shutdown(self):
pass
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
if job_uuid in self.jobs:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now(timezone.utc)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
except Exception:
job_status_response.status = JobStatus.failed
raise
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
return post_training_job
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
return self.jobs.get(job_uuid, None)
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import threading
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
JobID: TypeAlias = str
JobType: TypeAlias = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: Dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
LogMessage: TypeAlias = Tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
@property
def handler(self) -> JobHandler:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(timezone.utc), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(timezone.utc), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()

View file

@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
with pytest.raises(ValueError):
Scheduler(backend="unknown")
@pytest.mark.asyncio
async def test_scheduler_naive():
sched = Scheduler()
# make sure the scheduler starts empty
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_jobs() == []
called = False
# schedule a job that will exercise the handlers
async def job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
# exercise the handlers
on_log("test log1")
on_log("test log2")
on_artifact({"type": "type1", "path": "path1"})
on_artifact({"type": "type2", "path": "path2"})
on_status(JobStatus.completed)
job_id = "test_job_id"
job_type = "test_job_type"
sched.schedule(job_type, job_id, job_handler)
# make sure the job was properly registered
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_job(job_id) is not None
assert sched.get_jobs() == [sched.get_job(job_id)]
assert sched.get_jobs("unknown") == []
assert sched.get_jobs(job_type) == [sched.get_job(job_id)]
# now shut the scheduler down and make sure the job ran
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed
assert job.scheduled_at is not None
assert job.started_at is not None
assert job.completed_at is not None
assert job.scheduled_at < job.started_at < job.completed_at
assert job.artifacts == [
{"type": "type1", "path": "path1"},
{"type": "type2", "path": "path2"},
]
assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
sched = Scheduler()
async def failing_job_handler(on_log, on_status, on_artifact):
on_status(JobStatus.running)
raise ValueError("test error")
job_id = "test_job_id1"
job_type = "test_job_type"
sched.schedule(job_type, job_id, failing_job_handler)
job = sched.get_job(job_id)
assert job is not None
# confirm the exception made the job transition to failed state, even
# though it was set to `running` before the error
for _ in range(10):
if job.status == JobStatus.failed:
break
await asyncio.sleep(0.1)
assert job.status == JobStatus.failed
# confirm that the raised error got registered in log
assert job.logs[0][1] == "test error"
# even after failed job, we can schedule another one
called = False
async def successful_job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
on_status(JobStatus.completed)
job_id = "test_job_id2"
sched.schedule(job_type, job_id, successful_job_handler)
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed