From e5993c565e6cefac776b71eb04dea5e0371883a1 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Tue, 10 Dec 2024 15:24:46 -0800 Subject: [PATCH] misc --- llama_stack/apis/post_training/post_training.py | 2 +- .../inline/post_training/torchtune/post_training.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 62df340b3..235aed783 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -196,4 +196,4 @@ class PostTraining(Protocol): @webmethod(route="/post-training/job/artifacts", method="GET") async def get_training_job_artifacts( self, job_uuid: str - ) -> PostTrainingJobArtifactsResponse: ... + ) -> Optional[PostTrainingJobArtifactsResponse]: ... diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index a62f374a0..667940f32 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -111,8 +111,10 @@ class TorchtunePostTrainingImpl: @webmethod(route="/post-training/job/artifacts") async def get_training_job_artifacts( self, job_uuid: str - ) -> PostTrainingJobArtifactsResponse: - checkpoints = self.checkpoints_dict.get(job_uuid, []) - return PostTrainingJobArtifactsResponse( - job_uuid=job_uuid, checkpoints=checkpoints - ) + ) -> Optional[PostTrainingJobArtifactsResponse]: + if job_uuid in self.checkpoints_dict: + checkpoints = self.checkpoints_dict.get(job_uuid, []) + return PostTrainingJobArtifactsResponse( + job_uuid=job_uuid, checkpoints=checkpoints + ) + return None