Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit 214d0645ae ("add unit test"), parent c9a009b5e7.
6 changed files with 49 additions and 10 deletions.
@@ -26,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric]
+    training_metrics: Optional[PostTrainingMetric] = None
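With the "= None" default, training_metrics becomes optional at construction time. Below is a minimal sketch of the effect, assuming only the field names visible in this commit; the PostTrainingMetric stand-in is not the repository's real model.

from datetime import datetime
from typing import Optional

from pydantic import BaseModel


class PostTrainingMetric(BaseModel):
    # Stand-in only; the real model defines the actual metric fields.
    pass


class Checkpoint(BaseModel):
    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str
    training_metrics: Optional[PostTrainingMetric] = None  # now optional


# Without the "= None" default, omitting training_metrics here would raise a
# pydantic ValidationError.
ckpt = Checkpoint(
    identifier="Llama3.2-3B-Instruct-sft-0",
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="1234",
    path="/tmp/Llama3.2-3B-Instruct-sft-0",
)
print(ckpt.training_metrics)  # None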
@@ -154,7 +154,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):


 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,

@@ -171,7 +171,7 @@ class PostTraining(Protocol):
         ] = None,
     ) -> PostTrainingJob: ...

-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,

@@ -182,18 +182,18 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...

-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
     ) -> Optional[PostTrainingJobStatusResponse]: ...

-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...

-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
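These hunks add an explicit HTTP method to each webmethod route. As an illustration only, here is a toy stand-in for the decorator (not the llama-stack implementation) showing how route and method metadata can be attached to protocol methods:

from typing import Callable, Optional


def webmethod(route: str, method: Optional[str] = None) -> Callable[[Callable], Callable]:
    # Toy decorator: records the route and HTTP method on the function object.
    def wrap(fn: Callable) -> Callable:
        fn.__webmethod__ = {"route": route, "method": method}
        return fn

    return wrap


class Example:
    @webmethod(route="/post-training/jobs", method="GET")
    async def get_training_jobs(self) -> list: ...


print(Example.get_training_jobs.__webmethod__)
# {'route': '/post-training/jobs', 'method': 'GET'}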
@@ -152,4 +152,6 @@ class TorchtuneCheckpointer:
                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
            )

-        return model_file_path
+        print("model_file_path", str(model_file_path))
+
+        return str(model_file_path)
@@ -35,6 +35,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        if job_uuid in self.jobs_list:
+            raise ValueError(f"Job {job_uuid} already exists")
+
         post_training_job = PostTrainingJob(job_uuid=job_uuid)

         job_status_response = PostTrainingJobStatusResponse(

@@ -48,6 +51,7 @@ class TorchtunePostTrainingImpl:
         try:
             recipe = LoraFinetuningSingleDevice(
                 self.config,
+                job_uuid,
                 training_config,
                 hyperparam_search_config,
                 logger_config,
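The new guard rejects a second submission with the same job UUID. A self-contained sketch of that behavior follows, with a toy class standing in for TorchtunePostTrainingImpl (the real implementation also builds status responses and runs the recipe; for this sketch jobs_list is assumed to behave like a collection of job UUIDs):

import asyncio
from typing import List


class ToyPostTrainingImpl:
    def __init__(self) -> None:
        self.jobs_list: List[str] = []

    async def supervised_fine_tune(self, job_uuid: str) -> str:
        if job_uuid in self.jobs_list:
            raise ValueError(f"Job {job_uuid} already exists")
        self.jobs_list.append(job_uuid)
        return job_uuid


async def main() -> None:
    impl = ToyPostTrainingImpl()
    await impl.supervised_fine_tune("1234")
    try:
        await impl.supervised_fine_tune("1234")
    except ValueError as err:
        print(err)  # Job 1234 already exists


asyncio.run(main())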
@@ -68,6 +68,7 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],

@@ -76,6 +77,7 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")

@@ -366,7 +368,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler

-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}

         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())

@@ -396,7 +398,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )

@@ -514,6 +516,7 @@ class LoraFinetuningSingleDevice:
                 identifier=f"{self.model_id}-sft-{curr_epoch}",
                 created_at=datetime.now(),
                 epoch=curr_epoch,
+                post_training_job_id=self.job_uuid,
                 path=checkpoint_path,
             )
             checkpoints.append(checkpoint)
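Two related changes thread job tracking through the recipe: the job UUID is stored on the recipe, and save_checkpoint now returns the checkpoint path as a string so the training loop can record it on the Checkpoint together with post_training_job_id. A simplified, runnable sketch of that flow is below; ToyCheckpointer and ToyRecipe are stand-ins, not the real torchtune classes.

import asyncio
import tempfile
from pathlib import Path


class ToyCheckpointer:
    def __init__(self, output_dir: Path) -> None:
        self.output_dir = output_dir

    def save_checkpoint(self, ckpt_dict: dict, epoch: int) -> str:
        model_file_path = self.output_dir / f"epoch_{epoch}.pt"
        model_file_path.write_bytes(b"")  # placeholder for real serialization
        return str(model_file_path)  # returned as a plain string, as in the diff


class ToyRecipe:
    def __init__(self, job_uuid: str, checkpointer: ToyCheckpointer) -> None:
        self.job_uuid = job_uuid
        self._checkpointer = checkpointer

    async def save_checkpoint(self, epoch: int) -> str:
        ckpt_dict: dict = {}
        return self._checkpointer.save_checkpoint(ckpt_dict, epoch=epoch)


async def demo() -> None:
    with tempfile.TemporaryDirectory() as tmp:
        recipe = ToyRecipe(job_uuid="1234", checkpointer=ToyCheckpointer(Path(tmp)))
        checkpoint_path = await recipe.save_checkpoint(epoch=0)
        # The training loop records this value as Checkpoint.path, alongside
        # post_training_job_id=recipe.job_uuid.
        print(recipe.job_uuid, checkpoint_path)


asyncio.run(demo())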
@@ -59,3 +59,33 @@ class TestPostTraining:
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
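The new tests are async and rely on a post_training_stack fixture supplied by the providers' test conftest. Purely to illustrate that shape (assuming pytest-asyncio is installed; ToyStack is a stand-in, not the real fixture), a standalone equivalent looks like this:

import pytest


class ToyStack:
    async def get_training_jobs(self):
        return []


@pytest.fixture
def post_training_stack():
    # In the real suite this fixture resolves a configured post_training
    # provider implementation.
    return ToyStack()


@pytest.mark.asyncio
async def test_get_training_jobs_shape(post_training_stack):
    jobs = await post_training_stack.get_training_jobs()
    assert isinstance(jobs, list)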