mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
fix(tests): enable post-training tests
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
1341916caf
commit
35dcfff203
1 changed files with 47 additions and 33 deletions
|
@ -5,12 +5,15 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.common.job_types import JobStatus
|
||||||
from llama_stack.apis.post_training import (
|
from llama_stack.apis.post_training import (
|
||||||
DataConfig,
|
DataConfig,
|
||||||
LoraFinetuningConfig,
|
LoraFinetuningConfig,
|
||||||
|
@ -44,6 +47,15 @@ sys.stdout.reconfigure(line_buffering=True)
|
||||||
|
|
||||||
|
|
||||||
class TestPostTraining:
|
class TestPostTraining:
|
||||||
|
job_uuid = f"test-job{uuid.uuid4()}"
|
||||||
|
model = "ibm-granite/granite-3.3-2b-instruct"
|
||||||
|
|
||||||
|
def _validate_checkpoints(self, checkpoints):
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0]["identifier"] == f"{self.model}-sft-1"
|
||||||
|
assert checkpoints[0]["epoch"] == 1
|
||||||
|
assert "/.llama/checkpoints/merged_model" in checkpoints[0]["path"]
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"purpose, source",
|
"purpose, source",
|
||||||
|
@ -92,60 +104,62 @@ class TestPostTraining:
|
||||||
gradient_accumulation_steps=1,
|
gradient_accumulation_steps=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
job_uuid = f"test-job{uuid.uuid4()}"
|
logger.info(f"Starting training job with UUID: {self.job_uuid}")
|
||||||
logger.info(f"Starting training job with UUID: {job_uuid}")
|
|
||||||
|
|
||||||
# train with HF trl SFTTrainer as the default
|
# train with HF trl SFTTrainer as the default
|
||||||
|
os.makedirs("~/.llama/checkpoints/", exist_ok=True)
|
||||||
|
|
||||||
|
started = datetime.now(timezone.utc)
|
||||||
_ = llama_stack_client.post_training.supervised_fine_tune(
|
_ = llama_stack_client.post_training.supervised_fine_tune(
|
||||||
job_uuid=job_uuid,
|
job_uuid=self.job_uuid,
|
||||||
model="ibm-granite/granite-3.3-2b-instruct",
|
model=self.model,
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=training_config,
|
training_config=training_config,
|
||||||
hyperparam_search_config={},
|
hyperparam_search_config={},
|
||||||
logger_config={},
|
logger_config={},
|
||||||
checkpoint_dir=None,
|
checkpoint_dir="~/.llama/checkpoints/",
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
|
status = llama_stack_client.post_training.job.status(job_uuid=self.job_uuid)
|
||||||
if not status:
|
if not status:
|
||||||
logger.error("Job not found")
|
logger.error("Job not found")
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"Current status: {status}")
|
logger.info(f"Current status: {status}")
|
||||||
if status.status == "completed":
|
if status.status == "completed":
|
||||||
|
completed = datetime.now(timezone.utc)
|
||||||
|
assert status.completed_at is not None
|
||||||
|
assert status.completed_at >= started
|
||||||
|
assert status.completed_at <= completed
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info("Waiting for job to complete...")
|
logger.info("Waiting for job to complete...")
|
||||||
time.sleep(10) # Increased sleep time to reduce polling frequency
|
time.sleep(10) # Increased sleep time to reduce polling frequency
|
||||||
|
|
||||||
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
|
@pytest.mark.asyncio
|
||||||
logger.info(f"Job artifacts: {artifacts}")
|
def test_get_training_jobs(self, client_with_models):
|
||||||
|
jobs_list = client_with_models.post_training.job.list()
|
||||||
|
assert len(jobs_list) == 1
|
||||||
|
assert jobs_list[0].job_uuid == self.job_uuid
|
||||||
|
|
||||||
# TODO: Fix these tests to properly represent the Jobs API in training
|
@pytest.mark.asyncio
|
||||||
# @pytest.mark.asyncio
|
def test_get_training_job_status(self, client_with_models):
|
||||||
# async def test_get_training_jobs(self, post_training_stack):
|
job_status = client_with_models.post_training.job.status(job_uuid=self.job_uuid)
|
||||||
# post_training_impl = post_training_stack
|
assert job_status.job_uuid == self.job_uuid
|
||||||
# jobs_list = await post_training_impl.get_training_jobs()
|
assert job_status.status == JobStatus.completed.value
|
||||||
# assert isinstance(jobs_list, list)
|
assert isinstance(job_status.resources_allocated, dict)
|
||||||
# assert jobs_list[0].job_uuid == "1234"
|
self._validate_checkpoints(job_status.checkpoints)
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
assert job_status.scheduled_at is not None
|
||||||
# async def test_get_training_job_status(self, post_training_stack):
|
assert job_status.started_at is not None
|
||||||
# post_training_impl = post_training_stack
|
assert job_status.completed_at is not None
|
||||||
# job_status = await post_training_impl.get_training_job_status("1234")
|
|
||||||
# assert isinstance(job_status, PostTrainingJobStatusResponse)
|
|
||||||
# assert job_status.job_uuid == "1234"
|
|
||||||
# assert job_status.status == JobStatus.completed
|
|
||||||
# assert isinstance(job_status.checkpoints[0], Checkpoint)
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
assert job_status.scheduled_at <= job_status.started_at
|
||||||
# async def test_get_training_job_artifacts(self, post_training_stack):
|
assert job_status.started_at <= job_status.completed_at
|
||||||
# post_training_impl = post_training_stack
|
|
||||||
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
@pytest.mark.asyncio
|
||||||
# assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
|
def test_get_training_job_artifacts(self, client_with_models):
|
||||||
# assert job_artifacts.job_uuid == "1234"
|
job_artifacts = client_with_models.post_training.job.artifacts(job_uuid=self.job_uuid)
|
||||||
# assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
|
assert job_artifacts.job_uuid == self.job_uuid
|
||||||
# assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
|
self._validate_checkpoints(job_artifacts.checkpoints)
|
||||||
# assert job_artifacts.checkpoints[0].epoch == 0
|
|
||||||
# assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue