add unit test

Botao Chen 2024-12-10 14:57:03 -08:00
parent c9a009b5e7
commit 214d0645ae
6 changed files with 49 additions and 10 deletions


@@ -26,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric]
+    training_metrics: Optional[PostTrainingMetric] = None
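For context on the `= None` change: under pydantic v2, an `Optional[...]` annotation alone still leaves the field required, so constructing a Checkpoint without training_metrics (as the training recipe later in this commit does) would fail validation. A minimal sketch of that behavior, assuming pydantic v2 and with PostTrainingMetric stubbed purely for illustration:

from typing import Optional
from pydantic import BaseModel, ValidationError

class PostTrainingMetric(BaseModel):  # illustrative stub, not the real model
    loss: float = 0.0

class Required(BaseModel):
    training_metrics: Optional[PostTrainingMetric]        # required, may be None

class WithDefault(BaseModel):
    training_metrics: Optional[PostTrainingMetric] = None  # truly optional

try:
    Required()           # pydantic v2: field is required -> ValidationError
except ValidationError as err:
    print(err)

print(WithDefault())     # ok: training_metrics defaults to None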


@@ -154,7 +154,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -171,7 +171,7 @@ class PostTraining(Protocol):
         ] = None,
     ) -> PostTrainingJob: ...

-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
@@ -182,18 +182,18 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...

-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
     ) -> Optional[PostTrainingJobStatusResponse]: ...

-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...

-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
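With the HTTP verbs now explicit on each route above, a rough client-side sketch of hitting these endpoints over plain HTTP; the base URL and the query/body encoding are assumptions for illustration, not something this diff defines:

import httpx

BASE_URL = "http://localhost:5000"  # assumed server address

# GET routes: list jobs, then query one job's status and artifacts by uuid.
jobs = httpx.get(f"{BASE_URL}/post-training/jobs").json()
status = httpx.get(
    f"{BASE_URL}/post-training/job/status", params={"job_uuid": "1234"}
).json()
artifacts = httpx.get(
    f"{BASE_URL}/post-training/job/artifacts", params={"job_uuid": "1234"}
).json()

# POST route: cancel a job (request-body shape is an assumption).
httpx.post(f"{BASE_URL}/post-training/job/cancel", json={"job_uuid": "1234"})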


@@ -152,4 +152,6 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )
-        return model_file_path
+        print("model_file_path", str(model_file_path))
+        return str(model_file_path)


@@ -35,6 +35,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        if job_uuid in self.jobs_list:
+            raise ValueError(f"Job {job_uuid} already exists")
+
         post_training_job = PostTrainingJob(job_uuid=job_uuid)

         job_status_response = PostTrainingJobStatusResponse(
@@ -48,6 +51,7 @@ class TorchtunePostTrainingImpl:
         try:
             recipe = LoraFinetuningSingleDevice(
                 self.config,
+                job_uuid,
                 training_config,
                 hyperparam_search_config,
                 logger_config,
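The duplicate-uuid guard added above is straightforward to cover in isolation; below is a self-contained sketch of that pattern using a toy registry in place of the real TorchtunePostTrainingImpl (names and structure here are illustrative only):

import pytest

class ToyJobRegistry:
    """Stand-in that mimics the duplicate-uuid check, not the real impl."""

    def __init__(self) -> None:
        self.jobs_list: list[str] = []

    def submit(self, job_uuid: str) -> str:
        # Same guard as the hunk above: reject a uuid that was already submitted.
        if job_uuid in self.jobs_list:
            raise ValueError(f"Job {job_uuid} already exists")
        self.jobs_list.append(job_uuid)
        return job_uuid

def test_duplicate_job_uuid_rejected():
    registry = ToyJobRegistry()
    registry.submit("1234")
    with pytest.raises(ValueError):
        registry.submit("1234")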


@@ -68,6 +68,7 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
@@ -76,6 +77,7 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
@@ -366,7 +368,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler

-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}

         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@@ -396,7 +398,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )
@@ -514,6 +516,7 @@ class LoraFinetuningSingleDevice:
                     identifier=f"{self.model_id}-sft-{curr_epoch}",
                     created_at=datetime.now(),
                     epoch=curr_epoch,
+                    post_training_job_id=self.job_uuid,
                     path=checkpoint_path,
                 )
                 checkpoints.append(checkpoint)


@@ -59,3 +59,33 @@ class TestPostTraining:
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
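These new tests depend on the pytest-asyncio plugin (for @pytest.mark.asyncio) and on a post_training_stack fixture defined elsewhere in the suite; they also query the "1234" job submitted by the earlier test, so they assume that fine-tuning run has already completed. A hypothetical way to run just this file programmatically — the path is a guess, not taken from this diff:

import pytest

# Adjust the path to wherever test_post_training.py lives in the checkout.
pytest.main(["-v", "-s", "llama_stack/providers/tests/post_training/test_post_training.py"])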