diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py
index bb4639d17..9f0d887a7 100644
--- a/tests/integration/post_training/test_post_training.py
+++ b/tests/integration/post_training/test_post_training.py
@@ -5,12 +5,15 @@
 # the root directory of this source tree.
 
 import logging
+import os
 import sys
 import time
 import uuid
+from datetime import datetime, timezone
 
 import pytest
 
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     DataConfig,
     LoraFinetuningConfig,
@@ -44,6 +47,15 @@ sys.stdout.reconfigure(line_buffering=True)
 
 
 class TestPostTraining:
+    job_uuid = f"test-job{uuid.uuid4()}"
+    model = "ibm-granite/granite-3.3-2b-instruct"
+
+    def _validate_checkpoints(self, checkpoints):
+        assert len(checkpoints) == 1
+        assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
+        assert checkpoints[0]["epoch"] == 1
+        assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]
+
     @pytest.mark.integration
     @pytest.mark.parametrize(
         "purpose, source",
@@ -92,60 +104,62 @@ class TestPostTraining:
             gradient_accumulation_steps=1,
         )
 
-        job_uuid = f"test-job{uuid.uuid4()}"
-        logger.info(f"Starting training job with UUID: {job_uuid}")
+        logger.info(f"Starting training job with UUID: {self.job_uuid}")
 
         # train with HF trl SFTTrainer as the default
+        os.makedirs("~/.llama/checkpoints/", exist_ok=True)
+
+        started = datetime.now(timezone.utc)
         _ = llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid=job_uuid,
-            model="ibm-granite/granite-3.3-2b-instruct",
+            job_uuid=self.job_uuid,
+            model=self.model,
             algorithm_config=algorithm_config,
             training_config=training_config,
             hyperparam_search_config={},
             logger_config={},
-            checkpoint_dir=None,
+            checkpoint_dir="~/.llama/checkpoints/",
         )
 
         while True:
-            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
+            status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
             if not status:
                 logger.error("Job not found")
                 break
 
             logger.info(f"Current status: {status}")
             if status.status == "completed":
+                completed = datetime.now(timezone.utc)
+                assert status.completed_at is not None
+                assert status.completed_at >= started
+                assert status.completed_at <= completed
                 break
 
             logger.info("Waiting for job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency
 
-        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"Job artifacts: {artifacts}")
+    @pytest.mark.asyncio
+    def test_get_training_jobs(self, client_with_models):
+        jobs_list = client_with_models.post_training.job.list()
+        assert len(jobs_list) == 1
+        assert jobs_list[0].job_uuid == self.job_uuid
 
-    # TODO: Fix these tests to properly represent the Jobs API in training
-    # @pytest.mark.asyncio
-    # async def test_get_training_jobs(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     jobs_list = await post_training_impl.get_training_jobs()
-    #     assert isinstance(jobs_list, list)
-    #     assert jobs_list[0].job_uuid == "1234"
+    @pytest.mark.asyncio
+    def test_get_training_job_status(self, client_with_models):
+        job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
+        assert job_status.job_uuid == self.job_uuid
+        assert job_status.status == JobStatus.completed.value
+        assert isinstance(job_status.resources_allocated, dict)
+        self._validate_checkpoints(job_status.checkpoints)
 
-    # @pytest.mark.asyncio
-    # async def test_get_training_job_status(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_status = await post_training_impl.get_training_job_status("1234")
-    #     assert isinstance(job_status, PostTrainingJobStatusResponse)
-    #     assert job_status.job_uuid == "1234"
-    #     assert job_status.status == JobStatus.completed
-    #     assert isinstance(job_status.checkpoints[0], Checkpoint)
+        assert job_status.scheduled_at is not None
+        assert job_status.started_at is not None
+        assert job_status.completed_at is not None
 
-    # @pytest.mark.asyncio
-    # async def test_get_training_job_artifacts(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
-    #     assert job_artifacts.job_uuid == "1234"
-    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
-    #     assert job_artifacts.checkpoints[0].epoch == 0
-    #     assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+        assert job_status.scheduled_at <= job_status.started_at
+        assert job_status.started_at <= job_status.completed_at
+
+    @pytest.mark.asyncio
+    def test_get_training_job_artifacts(self, client_with_models):
+        job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
+        assert job_artifacts.job_uuid == self.job_uuid
+        self._validate_checkpoints(job_artifacts.checkpoints)