fix: miscellaneous job management improvements in torchtune (#1136)
- **refactor: simplify job status extraction a bit**
- **torchtune: save job status on schedule**
- **refactor: get rid of job_list in torchtune job management code**

# What does this PR do?

A failed job is now registered in the API, and one can consult its status.

## Test Plan

```
$ llama-stack-client post_training status --job-uuid test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73
JobStatusResponse(checkpoints=[], job_uuid='test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73', status='failed', completed_at=None, resources_allocated=None, scheduled_at=datetime.datetime(2025, 2, 18, 9, 4, 34, 3252), started_at=datetime.datetime(2025, 2, 18, 9, 4, 34, 10688))
```

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
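For reference, the bookkeeping this PR converges on can be reduced to a single dict keyed by job UUID. Below is a minimal, self-contained sketch of that pattern, not the actual provider code: the `JobRegistry` class and its plain-dict entries are hypothetical, while the real implementation stores `PostTrainingJobStatusResponse` objects, as the diff below shows.

```python
# Minimal sketch of the job-management pattern this PR lands (hypothetical
# JobRegistry class, not the actual provider code): one dict, keyed by job
# UUID, is the single source of truth for duplicate detection, status
# lookup, and listing.
from datetime import datetime
from enum import Enum
from typing import Optional


class JobStatus(Enum):
    scheduled = "scheduled"
    failed = "failed"


class JobRegistry:
    def __init__(self) -> None:
        self.jobs: dict[str, dict] = {}

    def schedule(self, job_uuid: str) -> None:
        # Duplicate detection is a plain membership test on the dict.
        if job_uuid in self.jobs:
            raise ValueError(f"Job {job_uuid} already exists")
        # The status entry is saved at schedule time, so a job that fails
        # right afterwards is still visible through the status endpoint.
        self.jobs[job_uuid] = {
            "status": JobStatus.scheduled,
            "scheduled_at": datetime.now(),
        }

    def mark_failed(self, job_uuid: str) -> None:
        # Failure just updates the entry created at schedule time.
        self.jobs[job_uuid]["status"] = JobStatus.failed

    def get_status(self, job_uuid: str) -> Optional[dict]:
        # Unknown UUIDs degrade gracefully to None, not a KeyError.
        return self.jobs.get(job_uuid, None)

    def list_jobs(self) -> list[str]:
        # The job list is derived from the dict keys on demand instead of
        # being tracked in a separate, drift-prone jobs_list.
        return list(self.jobs)
```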
parent 7972daa72e
commit c1f7d7f005
1 changed file with 6 additions and 12 deletions
```diff
@@ -40,8 +40,7 @@ class TorchtunePostTrainingImpl:
         self.datasets_api = datasets
 
         # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs_status = {}
-        self.jobs_list = []
+        self.jobs = {}
         self.checkpoints_dict = {}
 
     async def supervised_fine_tune(
@@ -54,8 +53,7 @@
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        for job in self.jobs_list:
-            if job_uuid == job.job_uuid:
-                raise ValueError(f"Job {job_uuid} already exists")
+        if job_uuid in self.jobs:
+            raise ValueError(f"Job {job_uuid} already exists")
 
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
@@ -65,8 +63,8 @@
             status=JobStatus.scheduled,
             scheduled_at=datetime.now(),
         )
+        self.jobs[job_uuid] = job_status_response
 
-        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
             try:
                 recipe = LoraFinetuningSingleDevice(
@@ -100,8 +98,6 @@
         else:
             raise NotImplementedError()
 
-        self.jobs_status[job_uuid] = job_status_response
-
         return post_training_job
 
     async def preference_optimize(
@@ -115,13 +111,11 @@
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(data=self.jobs_list)
+        return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        if job_uuid in self.jobs_status:
-            return self.jobs_status[job_uuid]
-        return None
+        return self.jobs.get(job_uuid, None)
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None:
```
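Design note: with `jobs_list` and `jobs_status` collapsed into the single `self.jobs` dict, there is no parallel structure to drift out of sync, and `get_training_jobs` derives its `PostTrainingJob` objects from the dict keys on demand. Exercising the hypothetical `JobRegistry` sketch from the top of this page mirrors the failed-job lookup in the Test Plan:

```python
# Exercising the hypothetical JobRegistry sketch above: a job that fails
# immediately after scheduling still has a queryable status entry.
registry = JobRegistry()
registry.schedule("test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73")
registry.mark_failed("test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73")

print(registry.get_status("test-jobe244b5b0-5053-4892-a4d9-d8fc8b116e73"))
# -> {'status': <JobStatus.failed: 'failed'>, 'scheduled_at': datetime.datetime(...)}
print(registry.get_status("no-such-job"))  # -> None, not a KeyError
print(registry.list_jobs())
```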