mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
refactor: get rid of job_list in torchtune job management code
The existing jobs_status dict already holds the same information, so jobs_list is redundant. (jobs_status is renamed to self.jobs in this commit.) Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
07a1d44f4c
commit
442f85ef4a
1 changed file with 6 additions and 9 deletions
|
@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
|
||||||
self.datasets_api = datasets
|
self.datasets_api = datasets
|
||||||
|
|
||||||
# TODO: assume sync job, will need jobs API for async scheduling
|
# TODO: assume sync job, will need jobs API for async scheduling
|
||||||
self.jobs_status = {}
|
self.jobs = {}
|
||||||
self.jobs_list = []
|
|
||||||
self.checkpoints_dict = {}
|
self.checkpoints_dict = {}
|
||||||
|
|
||||||
async def supervised_fine_tune(
|
async def supervised_fine_tune(
|
||||||
|
@ -54,9 +53,8 @@ class TorchtunePostTrainingImpl:
|
||||||
checkpoint_dir: Optional[str],
|
checkpoint_dir: Optional[str],
|
||||||
algorithm_config: Optional[AlgorithmConfig],
|
algorithm_config: Optional[AlgorithmConfig],
|
||||||
) -> PostTrainingJob:
|
) -> PostTrainingJob:
|
||||||
for job in self.jobs_list:
|
if job_uuid in self.jobs:
|
||||||
if job_uuid == job.job_uuid:
|
raise ValueError(f"Job {job_uuid} already exists")
|
||||||
raise ValueError(f"Job {job_uuid} already exists")
|
|
||||||
|
|
||||||
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
||||||
|
|
||||||
|
@ -65,9 +63,8 @@ class TorchtunePostTrainingImpl:
|
||||||
status=JobStatus.scheduled,
|
status=JobStatus.scheduled,
|
||||||
scheduled_at=datetime.now(),
|
scheduled_at=datetime.now(),
|
||||||
)
|
)
|
||||||
self.jobs_status[job_uuid] = job_status_response
|
self.jobs[job_uuid] = job_status_response
|
||||||
|
|
||||||
self.jobs_list.append(post_training_job)
|
|
||||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||||
try:
|
try:
|
||||||
recipe = LoraFinetuningSingleDevice(
|
recipe = LoraFinetuningSingleDevice(
|
||||||
|
@ -114,11 +111,11 @@ class TorchtunePostTrainingImpl:
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||||
return ListPostTrainingJobsResponse(data=self.jobs_list)
|
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/status")
|
@webmethod(route="/post-training/job/status")
|
||||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||||
return self.jobs_status.get(job_uuid, None)
|
return self.jobs.get(job_uuid, None)
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/cancel")
|
@webmethod(route="/post-training/job/cancel")
|
||||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue