add unit test

Botao Chen 2024-12-10 14:57:03 -08:00
parent c9a009b5e7
commit 214d0645ae
6 changed files with 49 additions and 10 deletions

View file

@@ -152,4 +152,6 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )
-        return model_file_path
+        print("model_file_path", str(model_file_path))
+        return str(model_file_path)
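A note on the new return value: a minimal sketch, assuming model_file_path is a pathlib.Path, of why the str() cast matters once the path is embedded in a JSON-serializable response (the path and variable below are hypothetical, not the repo's code):

from pathlib import Path
import json

model_file_path = Path("/tmp/checkpoints/epoch_0")   # hypothetical location

print(json.dumps({"path": str(model_file_path)}))    # works
# json.dumps({"path": model_file_path})              # TypeError: Object of type
#                                                    # PosixPath is not JSON serializable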

View file

@@ -35,6 +35,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        if job_uuid in self.jobs_list:
+            raise ValueError(f"Job {job_uuid} already exists")
+
         post_training_job = PostTrainingJob(job_uuid=job_uuid)

         job_status_response = PostTrainingJobStatusResponse(
@@ -48,6 +51,7 @@ class TorchtunePostTrainingImpl:
             try:
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
+                    job_uuid,
                     training_config,
                     hyperparam_search_config,
                     logger_config,
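The first hunk above adds a duplicate-UUID guard before any job state is created. A minimal sketch of that behavior, assuming jobs_list holds the UUIDs of already-scheduled jobs (the diff does not show its element type; the JobRegistry class below is a hypothetical stand-in, not the provider's code):

class JobRegistry:
    # Hypothetical stand-in for TorchtunePostTrainingImpl's job bookkeeping.
    def __init__(self) -> None:
        self.jobs_list: list[str] = []   # assumed: UUID strings

    def schedule(self, job_uuid: str) -> None:
        # Fail fast on a reused UUID rather than silently overwriting
        # the existing job's status and artifacts.
        if job_uuid in self.jobs_list:
            raise ValueError(f"Job {job_uuid} already exists")
        self.jobs_list.append(job_uuid)

registry = JobRegistry()
registry.schedule("1234")
# registry.schedule("1234")  # would raise ValueError: Job 1234 already exists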

View file

@@ -68,6 +68,7 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
@@ -76,6 +77,7 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
@@ -366,7 +368,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler

-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}

         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@@ -396,7 +398,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )
@@ -514,6 +516,7 @@ class LoraFinetuningSingleDevice:
                     identifier=f"{self.model_id}-sft-{curr_epoch}",
                     created_at=datetime.now(),
                     epoch=curr_epoch,
+                    post_training_job_id=self.job_uuid,
                     path=checkpoint_path,
                 )
                 checkpoints.append(checkpoint)
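Taken together, the hunks in this file thread the job UUID and the checkpoint path through to the artifact record. A condensed, self-contained sketch of that flow, where every class is a stand-in for the provider's real types, not the actual implementation:

import asyncio
from dataclasses import dataclass

@dataclass
class Checkpoint:            # stand-in for the provider's Checkpoint type
    identifier: str
    epoch: int
    post_training_job_id: str
    path: str

class Recipe:                # stand-in for LoraFinetuningSingleDevice
    def __init__(self, job_uuid: str) -> None:
        self.job_uuid = job_uuid

    async def save_checkpoint(self, epoch: int) -> str:
        # Report where the checkpoint was written instead of returning None,
        # mirroring the signature change in the diff.
        return f"/tmp/checkpoints/sft-{epoch}"     # hypothetical location

async def main() -> None:
    recipe = Recipe(job_uuid="1234")
    path = await recipe.save_checkpoint(epoch=0)
    ckpt = Checkpoint(
        identifier="Llama3.2-3B-Instruct-sft-0",
        epoch=0,
        post_training_job_id=recipe.job_uuid,      # ties the artifact to its job
        path=path,
    )
    print(ckpt)

asyncio.run(main())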

View file

@@ -59,3 +59,33 @@ class TestPostTraining:
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
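The three new tests depend on a post_training_stack fixture defined elsewhere in the suite; this diff does not show it. A self-contained sketch of the fixture pattern they assume, using pytest-asyncio and a hypothetical stub in place of the real provider stack (only the method name and the "1234" UUID come from the diff):

import pytest

class _StubImpl:
    # Hypothetical stand-in for the provider the real fixture builds,
    # pre-seeded with job "1234" as the tests above expect.
    async def get_training_jobs(self):
        Job = type("Job", (), {"job_uuid": "1234"})
        return [Job()]

@pytest.fixture
def post_training_stack():
    return _StubImpl()

@pytest.mark.asyncio
async def test_get_training_jobs(post_training_stack):
    jobs_list = await post_training_stack.get_training_jobs()
    assert jobs_list[0].job_uuid == "1234"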