Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-18 11:39:47 +00:00)
add unit test
commit 214d0645ae (parent c9a009b5e7)
6 changed files with 49 additions and 10 deletions
@@ -152,4 +152,6 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )
 
-        return model_file_path
+        print("model_file_path", str(model_file_path))
+
+        return str(model_file_path)
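The hunk above makes the checkpointer hand back the model path as a plain string rather than a pathlib.Path. A minimal sketch of the pattern, with hypothetical names (save_model, output_dir, and filename are illustrative, not the repository's actual signature): converting with str() before returning keeps downstream consumers that expect a plain string, such as a Pydantic field declared as path: str, working without extra conversion.

from pathlib import Path

def save_model(output_dir: str, filename: str) -> str:
    # Hypothetical helper: build the destination path for the saved weights.
    model_file_path = Path(output_dir) / filename
    # ... writing of the checkpoint file would happen here ...
    # Return a plain string rather than a Path so callers expecting `str` work as-is.
    return str(model_file_path)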
@@ -35,6 +35,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        if job_uuid in self.jobs_list:
+            raise ValueError(f"Job {job_uuid} already exists")
+
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
 
         job_status_response = PostTrainingJobStatusResponse(
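The added guard makes a second submission with the same job_uuid fail fast. A self-contained sketch of that behaviour, with illustrative names (JobRegistry and submit are not the provider's real class or method; the provider tracks jobs on a jobs_list attribute as the hunk shows):

class JobRegistry:
    """Illustrative stand-in for the provider's job tracking."""

    def __init__(self) -> None:
        self.jobs_list: list[str] = []

    def submit(self, job_uuid: str) -> None:
        # Reject duplicate job IDs, mirroring the guard added in the hunk above.
        if job_uuid in self.jobs_list:
            raise ValueError(f"Job {job_uuid} already exists")
        self.jobs_list.append(job_uuid)


registry = JobRegistry()
registry.submit("1234")
# registry.submit("1234")  # raises ValueError: Job 1234 already exists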
@@ -48,6 +51,7 @@ class TorchtunePostTrainingImpl:
         try:
             recipe = LoraFinetuningSingleDevice(
                 self.config,
+                job_uuid,
                 training_config,
                 hyperparam_search_config,
                 logger_config,
@@ -68,6 +68,7 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
@@ -76,6 +77,7 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
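The three hunks above thread the job_uuid from the provider into the recipe: it is passed when LoraFinetuningSingleDevice is constructed, accepted as a constructor parameter, and stored on the instance so later steps can tag their outputs with the job that produced them. A minimal sketch of that wiring, with a hypothetical Recipe class standing in for the real one:

class Recipe:
    """Illustrative recipe: remembers which job it belongs to."""

    def __init__(self, job_uuid: str) -> None:
        # Stored so artifacts created later (e.g. checkpoint metadata)
        # can reference the originating job.
        self.job_uuid = job_uuid


recipe = Recipe(job_uuid="1234")
assert recipe.job_uuid == "1234"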
@@ -366,7 +368,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler
 
-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
 
         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@@ -396,7 +398,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
 
-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )
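With the two hunks above, save_checkpoint reports where the checkpoint was written instead of returning nothing: the signature changes from -> None to -> str, and the method returns whatever the underlying checkpointer produced. A minimal sketch of the flow, using hypothetical classes (this Checkpointer is an illustrative wrapper, not torchtune's own API):

import asyncio


class Checkpointer:
    """Illustrative wrapper: writes the checkpoint and reports where it landed."""

    def save_checkpoint(self, state: dict, epoch: int) -> str:
        # Assumed location for the sketch; the real path depends on the config.
        return f"/tmp/checkpoints/epoch_{epoch}"


class Trainer:
    def __init__(self) -> None:
        self._checkpointer = Checkpointer()

    async def save_checkpoint(self, epoch: int) -> str:
        ckpt_dict: dict = {}
        # ... adapter weights and adapter config would be gathered here ...
        # Propagate the checkpointer's result so the caller can record the path.
        return self._checkpointer.save_checkpoint(ckpt_dict, epoch=epoch)


print(asyncio.run(Trainer().save_checkpoint(epoch=0)))  # /tmp/checkpoints/epoch_0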
@@ -514,6 +516,7 @@ class LoraFinetuningSingleDevice:
                     identifier=f"{self.model_id}-sft-{curr_epoch}",
                     created_at=datetime.now(),
                     epoch=curr_epoch,
                     post_training_job_id=self.job_uuid,
+                    path=checkpoint_path,
                 )
                 checkpoints.append(checkpoint)
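The unit test added by this commit is not part of the hunks shown in this section. Purely as an illustration of the kind of check these changes enable, here is a pytest-style sketch reusing the JobRegistry stand-in defined in the earlier sketch (hypothetical names throughout):

import pytest


def test_duplicate_job_rejected() -> None:
    registry = JobRegistry()  # illustrative class from the earlier sketch
    registry.submit("1234")
    # A second submission with the same job UUID should be rejected.
    with pytest.raises(ValueError, match="already exists"):
        registry.submit("1234")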