diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index c77d9305f..b837362d7 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl: self.datasets_api = datasets # TODO: assume sync job, will need jobs API for async scheduling - self.jobs_status = {} - self.jobs_list = [] + self.jobs = {} self.checkpoints_dict = {} async def supervised_fine_tune( @@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl: checkpoint_dir: Optional[str], algorithm_config: Optional[AlgorithmConfig], ) -> PostTrainingJob: - for job in self.jobs_list: - if job_uuid == job.job_uuid: - raise ValueError(f"Job {job_uuid} already exists") + if job_uuid in self.jobs: + raise ValueError(f"Job {job_uuid} already exists") post_training_job = PostTrainingJob(job_uuid=job_uuid) @@ -65,8 +63,8 @@ class TorchtunePostTrainingImpl: status=JobStatus.scheduled, scheduled_at=datetime.now(), ) + self.jobs[job_uuid] = job_status_response - self.jobs_list.append(post_training_job) if isinstance(algorithm_config, LoraFinetuningConfig): try: recipe = LoraFinetuningSingleDevice( @@ -100,8 +98,6 @@ class TorchtunePostTrainingImpl: else: raise NotImplementedError() - self.jobs_status[job_uuid] = job_status_response - return post_training_job async def preference_optimize( @@ -115,13 +111,11 @@ class TorchtunePostTrainingImpl: ) -> PostTrainingJob: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: - return ListPostTrainingJobsResponse(data=self.jobs_list) + return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs]) @webmethod(route="/post-training/job/status") async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: - if job_uuid in self.jobs_status: - return self.jobs_status[job_uuid] - return None + return self.jobs.get(job_uuid, None) @webmethod(route="/post-training/job/cancel") async def cancel_training_job(self, job_uuid: str) -> None: