This commit is contained in:
Botao Chen 2024-12-10 15:24:46 -08:00
parent 214d0645ae
commit e5993c565e
2 changed files with 8 additions and 6 deletions

View file

@ -196,4 +196,4 @@ class PostTraining(Protocol):
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...
) -> Optional[PostTrainingJobArtifactsResponse]: ...

View file

@ -111,8 +111,10 @@ class TorchtunePostTrainingImpl:
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
return None