From 442f85ef4af6040f32744b7a58b6679ad3a7e760 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 18 Feb 2025 09:45:12 -0500 Subject: [PATCH] refactor: get rid of job_list in torchtune job management code There's already jobs_status that has the same info. (Renamed it into self.jobs.) Signed-off-by: Ihar Hrachyshka --- .../post_training/torchtune/post_training.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index 52259258c..b837362d7 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl: self.datasets_api = datasets # TODO: assume sync job, will need jobs API for async scheduling - self.jobs_status = {} - self.jobs_list = [] + self.jobs = {} self.checkpoints_dict = {} async def supervised_fine_tune( @@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl: checkpoint_dir: Optional[str], algorithm_config: Optional[AlgorithmConfig], ) -> PostTrainingJob: - for job in self.jobs_list: - if job_uuid == job.job_uuid: - raise ValueError(f"Job {job_uuid} already exists") + if job_uuid in self.jobs: + raise ValueError(f"Job {job_uuid} already exists") post_training_job = PostTrainingJob(job_uuid=job_uuid) @@ -65,9 +63,8 @@ class TorchtunePostTrainingImpl: status=JobStatus.scheduled, scheduled_at=datetime.now(), ) - self.jobs_status[job_uuid] = job_status_response + self.jobs[job_uuid] = job_status_response - self.jobs_list.append(post_training_job) if isinstance(algorithm_config, LoraFinetuningConfig): try: recipe = LoraFinetuningSingleDevice( @@ -114,11 +111,11 @@ class TorchtunePostTrainingImpl: ) -> PostTrainingJob: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: - return ListPostTrainingJobsResponse(data=self.jobs_list) + return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs]) @webmethod(route="/post-training/job/status") async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: - return self.jobs_status.get(job_uuid, None) + return self.jobs.get(job_uuid, None) @webmethod(route="/post-training/job/cancel") async def cancel_training_job(self, job_uuid: str) -> None: