fix: miscellaneous job management improvements in torchtune (#1136)

- **refactor: simplify job status extraction a bit**
- **torchtune: save job status on schedule**
- **refactor: get rid of job_list in torchtune job management code**

# What does this PR do?

A failed job is now registered in the API, and its status can be queried afterwards (see the sketch below).

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])
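
With this change the provider keeps a single in-memory `self.jobs` dict keyed by job UUID, and the job's status is saved at schedule time rather than only after a successful run. The snippet below is a minimal, self-contained sketch of that pattern; it uses simplified stand-in types instead of the real `PostTrainingJobStatusResponse`/`JobStatus` models, and the `mark_failed` helper is hypothetical, shown only to illustrate how a failure path could make the job visible to status queries.

```python
# Minimal sketch of the single-dict job registry pattern used in this PR.
# JobStatusResponse below is a simplified stand-in for PostTrainingJobStatusResponse,
# and mark_failed() is a hypothetical helper, not part of the actual provider.
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Optional


class JobStatus(str, Enum):
    scheduled = "scheduled"
    failed = "failed"
    completed = "completed"


@dataclass
class JobStatusResponse:
    job_uuid: str
    status: JobStatus
    scheduled_at: datetime
    completed_at: Optional[datetime] = None


class JobRegistry:
    def __init__(self) -> None:
        # one dict replaces the old jobs_status dict + jobs_list pair
        self.jobs: dict[str, JobStatusResponse] = {}

    def schedule(self, job_uuid: str) -> None:
        if job_uuid in self.jobs:
            raise ValueError(f"Job {job_uuid} already exists")
        # status is saved on schedule, so a job that later fails is still queryable
        self.jobs[job_uuid] = JobStatusResponse(
            job_uuid=job_uuid, status=JobStatus.scheduled, scheduled_at=datetime.now()
        )

    def mark_failed(self, job_uuid: str) -> None:
        self.jobs[job_uuid].status = JobStatus.failed

    def status(self, job_uuid: str) -> Optional[JobStatusResponse]:
        return self.jobs.get(job_uuid, None)

    def list_job_uuids(self) -> list[str]:
        # the job list is derived from the dict keys
        return list(self.jobs)


registry = JobRegistry()
registry.schedule("test-job-1234")
registry.mark_failed("test-job-1234")  # e.g. recipe setup raised an exception
print(registry.status("test-job-1234").status)  # JobStatus.failed
```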

## Test Plan

```
$ llama-stack-client post_training status --job-uuid test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73                                                      
JobStatusResponse(checkpoints=[], job_uuid='test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73', status='failed', completed_at=None, resources_allocated=None, scheduled_at=datetime.datetime(2025, 2, 18, 9, 4, 34, 3252), started_at=datetime.datetime(2025, 2, 18, 9, 4, 34, 10688))
```
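
The same check can be done from Python. The following is a rough equivalent of the CLI call above; the resource path `client.post_training.job.status` and the base URL are assumptions and may differ depending on the installed `llama-stack-client` version and server configuration.

```python
# Rough Python equivalent of the CLI check above (method path and base URL are
# assumptions; adjust to your llama-stack-client version and server address).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

job_uuid = "test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73"
status = client.post_training.job.status(job_uuid=job_uuid)
print(status.status)        # e.g. "failed"
print(status.scheduled_at)  # set at schedule time, per this PR
```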

[//]: # (## Documentation)

---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Commit: c1f7d7f005 (parent: 7972daa72e)
Author: Ihar Hrachyshka, 2025-02-19 22:09:37 -05:00 (committed by GitHub)

```diff
@@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
         self.datasets_api = datasets
 
         # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs_status = {}
-        self.jobs_list = []
+        self.jobs = {}
         self.checkpoints_dict = {}
 
     async def supervised_fine_tune(
@@ -54,8 +53,7 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        for job in self.jobs_list:
-            if job_uuid == job.job_uuid:
-                raise ValueError(f"Job {job_uuid} already exists")
+        if job_uuid in self.jobs:
+            raise ValueError(f"Job {job_uuid} already exists")
 
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
@@ -65,8 +63,8 @@ class TorchtunePostTrainingImpl:
             status=JobStatus.scheduled,
             scheduled_at=datetime.now(),
         )
-        self.jobs_list.append(post_training_job)
+        self.jobs[job_uuid] = job_status_response
 
         if isinstance(algorithm_config, LoraFinetuningConfig):
             try:
                 recipe = LoraFinetuningSingleDevice(
@@ -100,8 +98,6 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()
 
-        self.jobs_status[job_uuid] = job_status_response
-
         return post_training_job
 
     async def preference_optimize(
@@ -115,13 +111,11 @@ class TorchtunePostTrainingImpl:
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(data=self.jobs_list)
+        return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        if job_uuid in self.jobs_status:
-            return self.jobs_status[job_uuid]
-        return None
+        return self.jobs.get(job_uuid, None)
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None:
```