fix(tests): enable post-training tests

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-25 21:42:12 +00:00
parent 1341916caf
commit 35dcfff203

View file

@ -5,12 +5,15 @@
# the root directory of this source tree.
import logging
import os
import sys
import time
import uuid
from datetime import datetime, timezone
import pytest
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
DataConfig,
LoraFinetuningConfig,
@ -44,6 +47,15 @@ sys.stdout.reconfigure(line_buffering=True)
class TestPostTraining:
job_uuid = f"test-job{uuid.uuid4()}"
model = "ibm-granite/granite-3.3-2b-instruct"
def _validate_checkpoints(self, checkpoints):
assert len(checkpoints) == 1
assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
assert checkpoints[0]["epoch"] == 1
assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]
@pytest.mark.integration
@pytest.mark.parametrize(
"purpose, source",
@ -92,60 +104,62 @@ class TestPostTraining:
gradient_accumulation_steps=1,
)
job_uuid = f"test-job{uuid.uuid4()}"
logger.info(f"Starting training job with UUID: {job_uuid}")
logger.info(f"Starting training job with UUID: {self.job_uuid}")
# train with HF trl SFTTrainer as the default
os.makedirs("~/.llama/checkpoints/", exist_ok=True)
started = datetime.now(timezone.utc)
_ = llama_stack_client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
model="ibm-granite/granite-3.3-2b-instruct",
job_uuid=self.job_uuid,
model=self.model,
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config={},
logger_config={},
checkpoint_dir=None,
checkpoint_dir="~/.llama/checkpoints/",
)
while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
if not status:
logger.error("Job not found")
break
logger.info(f"Current status: {status}")
if status.status == "completed":
completed = datetime.now(timezone.utc)
assert status.completed_at is not None
assert status.completed_at >= started
assert status.completed_at <= completed
break
logger.info("Waiting for job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"Job artifacts: {artifacts}")
@pytest.mark.asyncio
def test_get_training_jobs(self, client_with_models):
jobs_list = client_with_models.post_training.job.list()
assert len(jobs_list) == 1
assert jobs_list[0].job_uuid == self.job_uuid
# TODO: Fix these tests to properly represent the Jobs API in training
# @pytest.mark.asyncio
# async def test_get_training_jobs(self, post_training_stack):
# post_training_impl = post_training_stack
# jobs_list = await post_training_impl.get_training_jobs()
# assert isinstance(jobs_list, list)
# assert jobs_list[0].job_uuid == "1234"
@pytest.mark.asyncio
def test_get_training_job_status(self, client_with_models):
job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
assert job_status.job_uuid == self.job_uuid
assert job_status.status == JobStatus.completed.value
assert isinstance(job_status.resources_allocated, dict)
self._validate_checkpoints(job_status.checkpoints)
# @pytest.mark.asyncio
# async def test_get_training_job_status(self, post_training_stack):
# post_training_impl = post_training_stack
# job_status = await post_training_impl.get_training_job_status("1234")
# assert isinstance(job_status, PostTrainingJobStatusResponse)
# assert job_status.job_uuid == "1234"
# assert job_status.status == JobStatus.completed
# assert isinstance(job_status.checkpoints[0], Checkpoint)
assert job_status.scheduled_at is not None
assert job_status.started_at is not None
assert job_status.completed_at is not None
# @pytest.mark.asyncio
# async def test_get_training_job_artifacts(self, post_training_stack):
# post_training_impl = post_training_stack
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
# assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
# assert job_artifacts.job_uuid == "1234"
# assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
# assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
# assert job_artifacts.checkpoints[0].epoch == 0
# assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
assert job_status.scheduled_at <= job_status.started_at
assert job_status.started_at <= job_status.completed_at
@pytest.mark.asyncio
def test_get_training_job_artifacts(self, client_with_models):
job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
assert job_artifacts.job_uuid == self.job_uuid
self._validate_checkpoints(job_artifacts.checkpoints)