From 214d0645aed1c23c40352a904a8ba2b753c73b8a Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Tue, 10 Dec 2024 14:57:03 -0800
Subject: [PATCH] add unit test

---
 llama_stack/apis/common/training_types.py      |  2 +-
 .../apis/post_training/post_training.py        | 12 ++++----
 .../torchtune/common/checkpointer.py           |  4 ++-
 .../post_training/torchtune/post_training.py   |  4 +++
 .../recipes/lora_finetuning_single_device.py   |  7 +++--
 .../tests/post_training/test_post_training.py  | 30 +++++++++++++++++++
 6 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index 670877c8f..b4bd1b0c6 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -26,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric]
+    training_metrics: Optional[PostTrainingMetric] = None
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index b77747e3f..62df340b3 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -154,7 +154,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 
 
 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -171,7 +171,7 @@ class PostTraining(Protocol):
         ] = None,
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
@@ -182,18 +182,18 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
     ) -> Optional[PostTrainingJobStatusResponse]: ...
 
-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 9642f6ecc..688a03c25 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -152,4 +152,6 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )
 
-        return model_file_path
+        print("model_file_path", str(model_file_path))
+
+        return str(model_file_path)
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 1144d75c5..a62f374a0 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -35,6 +35,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        if job_uuid in self.jobs_list:
+            raise ValueError(f"Job {job_uuid} already exists")
+
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
 
         job_status_response = PostTrainingJobStatusResponse(
@@ -48,6 +51,7 @@ class TorchtunePostTrainingImpl:
             try:
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
+                    job_uuid,
                     training_config,
                     hyperparam_search_config,
                     logger_config,
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 5b8b8e30f..734b437b4 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -68,6 +68,7 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
@@ -76,6 +77,7 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
@@ -366,7 +368,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler
 
-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
 
         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@@ -396,7 +398,7 @@ class LoraFinetuningSingleDevice:
             }
             ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
 
-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )
@@ -514,6 +516,7 @@ class LoraFinetuningSingleDevice:
                     identifier=f"{self.model_id}-sft-{curr_epoch}",
                     created_at=datetime.now(),
                     epoch=curr_epoch,
+                    post_training_job_id=self.job_uuid,
                     path=checkpoint_path,
                 )
                 checkpoints.append(checkpoint)
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py
index a4e2d55c9..6959f9f9c 100644
--- a/llama_stack/providers/tests/post_training/test_post_training.py
+++ b/llama_stack/providers/tests/post_training/test_post_training.py
@@ -59,3 +59,33 @@ class TestPostTraining:
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
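
A minimal sketch of how the new tests might be exercised locally, assuming a pytest environment with the post_training provider fixtures under llama_stack/providers/tests configured; the exact fixture/config selection and flags may differ per setup:

    # Hypothetical invocation, not taken from this patch.
    pytest llama_stack/providers/tests/post_training/test_post_training.py -v -s --tb=short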