feat: Implement async job execution for torchtune training (#1437)

# What does this PR do?

Now a separate thread is started to execute training jobs. Training
requests now return job ID before the job completes. (Which fixes API
timeouts for any jobs that take longer than a minute.)

Note: the scheduler code is meant to be spun out in the future into a
common provider service that can be reused for different APIs and
providers. It is also expected to back the /jobs API proposed here:

https://github.com/meta-llama/llama-stack/discussions/1238

Hence its somewhat generalized form which is expected to simplify its
adoption elsewhere in the future.

Note: this patch doesn't attempt to implement missing APIs (e.g. cancel
or job removal). This work will belong to follow-up PRs.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

Added unit tests for the scheduler module. For the API coverage, did
manual testing and was able to run a training cycle on GPU. The initial
call returned job ID before the training completed, as (now) expected.
Artifacts are returned as expected.

```
JobArtifactsResponse(checkpoints=[{'identifier': 'meta-llama/Llama-3.2-3B-Instruct-sft-0', 'created_at': '2025-03-07T22:45:19.892714', 'epoch': 0, 'post_training_job_id': 'test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50', 'path': '/home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0', 'training_metrics': None}], job_uuid='test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50')
```

The integration test is currently disabled for the provider. I will look
into how it can be enabled in a different PR / issue context.

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-04-14 11:59:11 -04:00 committed by GitHub
parent 7641a5cd0b
commit 3ed4316ed5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 472 additions and 39 deletions

View file

@ -3,13 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self) -> None:
await self._scheduler.shutdown()
async def shutdown(self):
pass
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
if job_uuid in self.jobs:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now(timezone.utc)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
except Exception:
job_status_response.status = JobStatus.failed
raise
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
return post_training_job
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
return self.jobs.get(job_uuid, None)
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))