fix(tests): enable post-training tests

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-25 21:42:12 +00:00
parent 1341916caf
commit 35dcfff203

@@ -5,12 +5,15 @@
 # the root directory of this source tree.
 import logging
+import os
 import sys
 import time
 import uuid
+from datetime import datetime, timezone

 import pytest

+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     DataConfig,
     LoraFinetuningConfig,
@@ -44,6 +47,15 @@ sys.stdout.reconfigure(line_buffering=True)
 class TestPostTraining:
+    job_uuid = f"test-job{uuid.uuid4()}"
+    model = "ibm-granite/granite-3.3-2b-instruct"
+
+    def _validate_checkpoints(self, checkpoints):
+        assert len(checkpoints) == 1
+        assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
+        assert checkpoints[0]["epoch"] == 1
+        assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]
+
     @pytest.mark.integration
     @pytest.mark.parametrize(
         "purpose, source",
@@ -92,60 +104,62 @@ class TestPostTraining:
             gradient_accumulation_steps=1,
         )

-        job_uuid = f"test-job{uuid.uuid4()}"
-        logger.info(f"Starting training job with UUID: {job_uuid}")
+        logger.info(f"Starting training job with UUID: {self.job_uuid}")

         # train with HF trl SFTTrainer as the default
+        os.makedirs("~/.llama/checkpoints/", exist_ok=True)
+        started = datetime.now(timezone.utc)
         _ = llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid=job_uuid,
-            model="ibm-granite/granite-3.3-2b-instruct",
+            job_uuid=self.job_uuid,
+            model=self.model,
             algorithm_config=algorithm_config,
             training_config=training_config,
             hyperparam_search_config={},
             logger_config={},
-            checkpoint_dir=None,
+            checkpoint_dir="~/.llama/checkpoints/",
         )

         while True:
-            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
+            status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
             if not status:
                 logger.error("Job not found")
                 break

             logger.info(f"Current status: {status}")
             if status.status == "completed":
+                completed = datetime.now(timezone.utc)
+                assert status.completed_at is not None
+                assert status.completed_at >= started
+                assert status.completed_at <= completed
                 break

             logger.info("Waiting for job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency

-        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"Job artifacts: {artifacts}")
-
-    # TODO: Fix these tests to properly represent the Jobs API in training
-    # @pytest.mark.asyncio
-    # async def test_get_training_jobs(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     jobs_list = await post_training_impl.get_training_jobs()
-    #     assert isinstance(jobs_list, list)
-    #     assert jobs_list[0].job_uuid == "1234"
-
-    # @pytest.mark.asyncio
-    # async def test_get_training_job_status(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_status = await post_training_impl.get_training_job_status("1234")
-    #     assert isinstance(job_status, PostTrainingJobStatusResponse)
-    #     assert job_status.job_uuid == "1234"
-    #     assert job_status.status == JobStatus.completed
-    #     assert isinstance(job_status.checkpoints[0], Checkpoint)
-
-    # @pytest.mark.asyncio
-    # async def test_get_training_job_artifacts(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
-    #     assert job_artifacts.job_uuid == "1234"
-    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
-    #     assert job_artifacts.checkpoints[0].epoch == 0
-    #     assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+    @pytest.mark.asyncio
+    def test_get_training_jobs(self, client_with_models):
+        jobs_list = client_with_models.post_training.job.list()
+        assert len(jobs_list) == 1
+        assert jobs_list[0].job_uuid == self.job_uuid
+
+    @pytest.mark.asyncio
+    def test_get_training_job_status(self, client_with_models):
+        job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
+        assert job_status.job_uuid == self.job_uuid
+        assert job_status.status == JobStatus.completed.value
+        assert isinstance(job_status.resources_allocated, dict)
+        self._validate_checkpoints(job_status.checkpoints)
+        assert job_status.scheduled_at is not None
+        assert job_status.started_at is not None
+        assert job_status.completed_at is not None
+        assert job_status.scheduled_at <= job_status.started_at
+        assert job_status.started_at <= job_status.completed_at
+
+    @pytest.mark.asyncio
+    def test_get_training_job_artifacts(self, client_with_models):
+        job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
+        assert job_artifacts.job_uuid == self.job_uuid
+        self._validate_checkpoints(job_artifacts.checkpoints)