From c1f7d7f005147967e9e786ed49fb8cf9cac7c8ff Mon Sep 17 00:00:00 2001
From: Ihar Hrachyshka
Date: Wed, 19 Feb 2025 22:09:37 -0500
Subject: [PATCH] fix: miscellaneous job management improvements in torchtune (#1136)

- **refactor: simplify job status extraction a bit**
- **torchtune: save job status on schedule**
- **refactor: get rid of job_list in torchtune job management code**

# What does this PR do?

A failed job is now registered in the API, and one can consult its status.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```
$ llama-stack-client post_training status --job-uuid test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73
JobStatusResponse(checkpoints=[], job_uuid='test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73', status='failed', completed_at=None, resources_allocated=None, scheduled_at=datetime.datetime(2025, 2, 18, 9, 4, 34, 3252), started_at=datetime.datetime(2025, 2, 18, 9, 4, 34, 10688))
```

[//]: # (## Documentation)

---------

Signed-off-by: Ihar Hrachyshka
---
 .../post_training/torchtune/post_training.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index c77d9305f..b837362d7 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
         self.datasets_api = datasets
 
         # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs_status = {}
-        self.jobs_list = []
+        self.jobs = {}
         self.checkpoints_dict = {}
 
     async def supervised_fine_tune(
@@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        for job in self.jobs_list:
-            if job_uuid == job.job_uuid:
-                raise ValueError(f"Job {job_uuid} already exists")
+        if job_uuid in self.jobs:
+            raise ValueError(f"Job {job_uuid} already exists")
 
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
 
@@ -65,8 +63,8 @@ class TorchtunePostTrainingImpl:
             status=JobStatus.scheduled,
             scheduled_at=datetime.now(),
         )
+        self.jobs[job_uuid] = job_status_response
 
-        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
             try:
                 recipe = LoraFinetuningSingleDevice(
@@ -100,8 +98,6 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()
 
-        self.jobs_status[job_uuid] = job_status_response
-
         return post_training_job
 
     async def preference_optimize(
@@ -115,13 +111,11 @@ class TorchtunePostTrainingImpl:
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(data=self.jobs_list)
+        return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        if job_uuid in self.jobs_status:
-            return self.jobs_status[job_uuid]
-        return None
+        return self.jobs.get(job_uuid, None)
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None: