Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)
fix(tests): enable post-training tests

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

parent 1341916caf
commit 35dcfff203

1 changed file with 47 additions and 33 deletions
@@ -5,12 +5,15 @@
 # the root directory of this source tree.

 import logging
+import os
 import sys
 import time
 import uuid
 from datetime import datetime, timezone

 import pytest

+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     DataConfig,
     LoraFinetuningConfig,

@@ -44,6 +47,15 @@ sys.stdout.reconfigure(line_buffering=True)

 class TestPostTraining:
+    job_uuid = f"test-job{uuid.uuid4()}"
+    model = "ibm-granite/granite-3.3-2b-instruct"
+
+    def _validate_checkpoints(self, checkpoints):
+        assert len(checkpoints) == 1
+        assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
+        assert checkpoints[0]["epoch"] == 1
+        assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]
+
     @pytest.mark.integration
     @pytest.mark.parametrize(
         "purpose, source",

@@ -92,60 +104,62 @@ class TestPostTraining:
             gradient_accumulation_steps=1,
         )

-        job_uuid = f"test-job{uuid.uuid4()}"
-        logger.info(f"Starting training job with UUID: {job_uuid}")
+        logger.info(f"Starting training job with UUID: {self.job_uuid}")

         # train with HF trl SFTTrainer as the default
+        os.makedirs("~/.llama/checkpoints/", exist_ok=True)
+
         started = datetime.now(timezone.utc)
         _ = llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid=job_uuid,
-            model="ibm-granite/granite-3.3-2b-instruct",
+            job_uuid=self.job_uuid,
+            model=self.model,
             algorithm_config=algorithm_config,
             training_config=training_config,
             hyperparam_search_config={},
             logger_config={},
-            checkpoint_dir=None,
+            checkpoint_dir="~/.llama/checkpoints/",
         )

         while True:
-            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
+            status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
             if not status:
                 logger.error("Job not found")
                 break

             logger.info(f"Current status: {status}")
             if status.status == "completed":
                 completed = datetime.now(timezone.utc)
                 assert status.completed_at is not None
                 assert status.completed_at >= started
                 assert status.completed_at <= completed
                 break

             logger.info("Waiting for job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency

-        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"Job artifacts: {artifacts}")
+    @pytest.mark.asyncio
+    def test_get_training_jobs(self, client_with_models):
+        jobs_list = client_with_models.post_training.job.list()
+        assert len(jobs_list) == 1
+        assert jobs_list[0].job_uuid == self.job_uuid

-    # TODO: Fix these tests to properly represent the Jobs API in training
-    # @pytest.mark.asyncio
-    # async def test_get_training_jobs(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     jobs_list = await post_training_impl.get_training_jobs()
-    #     assert isinstance(jobs_list, list)
-    #     assert jobs_list[0].job_uuid == "1234"
+    @pytest.mark.asyncio
+    def test_get_training_job_status(self, client_with_models):
+        job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
+        assert job_status.job_uuid == self.job_uuid
+        assert job_status.status == JobStatus.completed.value
+        assert isinstance(job_status.resources_allocated, dict)
+        self._validate_checkpoints(job_status.checkpoints)

-    # @pytest.mark.asyncio
-    # async def test_get_training_job_status(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_status = await post_training_impl.get_training_job_status("1234")
-    #     assert isinstance(job_status, PostTrainingJobStatusResponse)
-    #     assert job_status.job_uuid == "1234"
-    #     assert job_status.status == JobStatus.completed
-    #     assert isinstance(job_status.checkpoints[0], Checkpoint)
+        assert job_status.scheduled_at is not None
+        assert job_status.started_at is not None
+        assert job_status.completed_at is not None

-    # @pytest.mark.asyncio
-    # async def test_get_training_job_artifacts(self, post_training_stack):
-    #     post_training_impl = post_training_stack
-    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
-    #     assert job_artifacts.job_uuid == "1234"
-    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
-    #     assert job_artifacts.checkpoints[0].epoch == 0
-    #     assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+        assert job_status.scheduled_at <= job_status.started_at
+        assert job_status.started_at <= job_status.completed_at

+    @pytest.mark.asyncio
+    def test_get_training_job_artifacts(self, client_with_models):
+        job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
+        assert job_artifacts.job_uuid == self.job_uuid
+        self._validate_checkpoints(job_artifacts.checkpoints)
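
For context, the status-polling pattern the enabled tests rely on looks like this in isolation. This is a minimal sketch, not part of the commit: the LlamaStackClient construction and base_url are assumptions (the tests receive pre-built llama_stack_client / client_with_models fixtures), and the job UUID below is a placeholder.

import time

from llama_stack_client import LlamaStackClient  # assumed: the client package behind the test fixtures

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local stack endpoint

job_uuid = "test-job-example"  # placeholder; the tests derive theirs from uuid.uuid4()

# Poll the post-training job until it reaches a terminal state,
# mirroring the while-loop in the fine-tune test above.
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status or status.status == "completed":
        break
    time.sleep(10)  # same 10-second interval the test uses to limit polling frequency

Because job_uuid and model are now class attributes, the later test_get_training_job_status and test_get_training_job_artifacts methods can query the same job the fine-tune test started.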