mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00

add unit test

commit 214d0645ae, parent c9a009b5e7
6 changed files with 49 additions and 10 deletions
@@ -26,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric]
+    training_metrics: Optional[PostTrainingMetric] = None
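With the default of `None` in place, `training_metrics` can simply be omitted when a checkpoint is created. A minimal sketch, using the field names from the hunk above; the import path and the concrete values are assumptions, not shown in this diff:

```python
from datetime import datetime

from llama_stack.apis.post_training import Checkpoint  # assumed import path

# Hypothetical values; the kwargs mirror the Checkpoint construction later in this diff.
ckpt = Checkpoint(
    identifier="Llama3.2-3B-Instruct-sft-0",
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="1234",
    path="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0",
    # training_metrics now defaults to None, so it no longer has to be passed
)
```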
@@ -154,7 +154,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):


 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,

@@ -171,7 +171,7 @@ class PostTraining(Protocol):
         ] = None,
     ) -> PostTrainingJob: ...

-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,

@@ -182,18 +182,18 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...

-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
     ) -> Optional[PostTrainingJobStatusResponse]: ...

-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...

-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
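Each route now declares its HTTP verb explicitly: the mutating endpoints (fine-tune, preference-optimize, cancel) are POST, while the read-only ones (jobs, status, artifacts) are GET. A minimal sketch of what that implies for a plain HTTP client; the base URL and the exact request/response shapes are assumptions, not part of this diff:

```python
import httpx

BASE = "http://localhost:5000"  # hypothetical server address

# Read-only endpoints are now GET ...
jobs = httpx.get(f"{BASE}/post-training/jobs").json()
status = httpx.get(
    f"{BASE}/post-training/job/status", params={"job_uuid": "1234"}
).json()

# ... while endpoints that change state are POST.
httpx.post(f"{BASE}/post-training/job/cancel", json={"job_uuid": "1234"})
```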
@@ -152,4 +152,6 @@ class TorchtuneCheckpointer:
                 "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
             )

-        return model_file_path
+        print("model_file_path", str(model_file_path))
+
+        return str(model_file_path)
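Returning `str(model_file_path)` rather than the bare path object matters because `Checkpoint.path` is typed as `str`. A one-line illustration of the distinction, assuming the checkpointer builds the path with pathlib (the actual directory is hypothetical):

```python
from pathlib import Path

model_file_path = Path.home() / ".llama" / "checkpoints" / "Llama3.2-3B-Instruct-sft-0"
assert not isinstance(model_file_path, str)   # a Path object, not a str
assert isinstance(str(model_file_path), str)  # the form the updated code returns
```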
@@ -35,6 +35,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        if job_uuid in self.jobs_list:
+            raise ValueError(f"Job {job_uuid} already exists")
+
         post_training_job = PostTrainingJob(job_uuid=job_uuid)

         job_status_response = PostTrainingJobStatusResponse(

@@ -48,6 +51,7 @@ class TorchtunePostTrainingImpl:
         try:
             recipe = LoraFinetuningSingleDevice(
                 self.config,
+                job_uuid,
                 training_config,
                 hyperparam_search_config,
                 logger_config,
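The new guard rejects a second submission with the same UUID before any work starts. A standalone sketch of the same pattern; `jobs` stands in for `self.jobs_list`, whose element type is not shown in the diff, so UUID strings are assumed here:

```python
jobs: list[str] = []  # stand-in for self.jobs_list; element type assumed

def submit(job_uuid: str) -> None:
    if job_uuid in jobs:
        raise ValueError(f"Job {job_uuid} already exists")
    jobs.append(job_uuid)

submit("1234")
try:
    submit("1234")  # duplicate submission
except ValueError as err:
    print(err)  # -> Job 1234 already exists
```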
@@ -68,6 +68,7 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],

@@ -76,6 +77,7 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
+        self.job_uuid = job_uuid
+
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
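Threading `job_uuid` through the recipe's constructor lets every checkpoint the recipe saves point back to its owning job via `post_training_job_id` (see the Checkpoint construction further down). A stripped-down sketch of that flow, with every unrelated constructor argument elided:

```python
class RecipeSketch:
    """Stand-in for LoraFinetuningSingleDevice; everything except job_uuid elided."""

    def __init__(self, job_uuid: str) -> None:
        self.job_uuid = job_uuid  # recorded so saved checkpoints can reference the job

recipe = RecipeSketch(job_uuid="1234")
print(recipe.job_uuid)  # -> 1234
```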
@@ -366,7 +368,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler

-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}

         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())

@@ -396,7 +398,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )
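Taken together, these two hunks make `save_checkpoint` propagate the file path produced by `TorchtuneCheckpointer.save_checkpoint` up to the training loop, which stores it on the Checkpoint as `path` (next hunk). A minimal sketch of the consuming side, assuming `recipe` is an already-constructed `LoraFinetuningSingleDevice`:

```python
async def record_epoch(recipe, epoch: int) -> str:
    # save_checkpoint now returns the saved checkpoint's path as a str,
    # ready to be stored on the Checkpoint model as `path`.
    return await recipe.save_checkpoint(epoch=epoch)
```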
@@ -514,6 +516,7 @@ class LoraFinetuningSingleDevice:
                 identifier=f"{self.model_id}-sft-{curr_epoch}",
                 created_at=datetime.now(),
                 epoch=curr_epoch,
                 post_training_job_id=self.job_uuid,
+                path=checkpoint_path,
             )
             checkpoints.append(checkpoint)
@@ -59,3 +59,33 @@ class TestPostTraining:
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
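These tests exercise the new GET endpoints end to end against the torchtune provider; the `post_training_stack` fixture that wires everything up is defined elsewhere in the test suite and is not part of this diff. Assuming the repository's standard pytest setup, they can be selected with an ordinary keyword filter such as `pytest -k TestPostTraining` (the exact test-file path is not shown here).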