diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 52259258c..b837362d7 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
         self.datasets_api = datasets
 
         # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs_status = {}
-        self.jobs_list = []
+        self.jobs = {}
         self.checkpoints_dict = {}
 
     async def supervised_fine_tune(
@@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        for job in self.jobs_list:
-            if job_uuid == job.job_uuid:
-                raise ValueError(f"Job {job_uuid} already exists")
+        if job_uuid in self.jobs:
+            raise ValueError(f"Job {job_uuid} already exists")
 
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
 
@@ -65,9 +63,8 @@ class TorchtunePostTrainingImpl:
             status=JobStatus.scheduled,
             scheduled_at=datetime.now(),
         )
-        self.jobs_status[job_uuid] = job_status_response
+        self.jobs[job_uuid] = job_status_response
 
-        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
             try:
                 recipe = LoraFinetuningSingleDevice(
@@ -114,11 +111,11 @@ class TorchtunePostTrainingImpl:
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(data=self.jobs_list)
+        return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        return self.jobs_status.get(job_uuid, None)
+        return self.jobs.get(job_uuid, None)
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None:
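
The diff above collapses the parallel `jobs_status` dict and `jobs_list` list into a single `self.jobs` dict keyed by `job_uuid`, so duplicate detection becomes a membership test and the job listing is derived from the dict keys. Below is a minimal standalone sketch of that pattern, not part of the patch: `JobRegistry` and its simplified status payloads are hypothetical stand-ins for the provider's `PostTrainingJobStatusResponse` bookkeeping, included only to show the single-dict behavior.

    # Minimal sketch (hypothetical names): one dict keyed by job_uuid replaces
    # the previous jobs_status dict plus jobs_list pair.
    from datetime import datetime


    class JobRegistry:
        def __init__(self) -> None:
            # Single source of truth: job_uuid -> status record
            self.jobs: dict[str, dict] = {}

        def schedule(self, job_uuid: str) -> None:
            # Duplicate check is a plain membership test on the dict
            if job_uuid in self.jobs:
                raise ValueError(f"Job {job_uuid} already exists")
            self.jobs[job_uuid] = {"status": "scheduled", "scheduled_at": datetime.now()}

        def list_jobs(self) -> list[str]:
            # The listing is derived from the dict keys instead of a second list
            return list(self.jobs)

        def get_status(self, job_uuid: str) -> dict | None:
            return self.jobs.get(job_uuid)


    registry = JobRegistry()
    registry.schedule("job-1")
    assert registry.list_jobs() == ["job-1"]
    assert registry.get_status("job-1")["status"] == "scheduled"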