mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 04:02:00 +00:00
feat: add integration tests for post_training
set inline::huggingface as the default post_training provider for the ollama distribution and add integration tests for post_training Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
7dcb997f17
commit
ff246d890a
10 changed files with 161 additions and 53 deletions
|
|
@ -4,20 +4,38 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def capture_output(capsys):
|
||||
"""Fixture to capture and display output during test execution."""
|
||||
yield
|
||||
captured = capsys.readouterr()
|
||||
if captured.out:
|
||||
print("\nCaptured stdout:", captured.out)
|
||||
if captured.err:
|
||||
print("\nCaptured stderr:", captured.err)
|
||||
|
||||
|
||||
# Force flush stdout to see prints immediately
|
||||
sys.stdout.reconfigure(line_buffering=True)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/post_training/test_post_training.py
|
||||
|
|
@ -25,10 +43,31 @@ from llama_stack.apis.post_training import (
|
|||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
|
||||
class TestPostTraining:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervised_fine_tune(self, post_training_stack):
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize(
|
||||
"purpose, source",
|
||||
[
|
||||
(
|
||||
"post-training/messages",
|
||||
{
|
||||
"type": "uri",
|
||||
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(360) # 6 minutes timeout
|
||||
def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
|
||||
logger.info("Starting supervised fine-tuning test")
|
||||
|
||||
# register dataset to train
|
||||
dataset = llama_stack_client.datasets.register(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
)
|
||||
logger.info(f"Registered dataset with ID: {dataset.identifier}")
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
|
||||
|
|
@ -39,62 +78,74 @@ class TestPostTraining:
|
|||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="alpaca",
|
||||
dataset_id=dataset.identifier,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
data_format="instruct",
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type="adamw",
|
||||
lr=3e-4,
|
||||
lr_min=3e-5,
|
||||
weight_decay=0.1,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
# setup training config with minimal settings
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
max_steps_per_epoch=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
post_training_impl = post_training_stack
|
||||
response = await post_training_impl.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="Llama3.2-3B-Instruct",
|
||||
|
||||
job_uuid = f"test-job{uuid.uuid4()}"
|
||||
logger.info(f"Starting training job with UUID: {job_uuid}")
|
||||
|
||||
# train with HF trl SFTTrainer as the default
|
||||
_ = llama_stack_client.post_training.supervised_fine_tune(
|
||||
job_uuid=job_uuid,
|
||||
model="ibm-granite/granite-3.3-2b-instruct",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
hyperparam_search_config={},
|
||||
logger_config={},
|
||||
checkpoint_dir="null",
|
||||
checkpoint_dir=None,
|
||||
)
|
||||
assert isinstance(response, PostTrainingJob)
|
||||
assert response.job_uuid == "1234"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_jobs(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
jobs_list = await post_training_impl.get_training_jobs()
|
||||
assert isinstance(jobs_list, list)
|
||||
assert jobs_list[0].job_uuid == "1234"
|
||||
while True:
|
||||
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
|
||||
if not status:
|
||||
logger.error("Job not found")
|
||||
break
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_job_status(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
job_status = await post_training_impl.get_training_job_status("1234")
|
||||
assert isinstance(job_status, PostTrainingJobStatusResponse)
|
||||
assert job_status.job_uuid == "1234"
|
||||
assert job_status.status == JobStatus.completed
|
||||
assert isinstance(job_status.checkpoints[0], Checkpoint)
|
||||
logger.info(f"Current status: {status}")
|
||||
if status.status == "completed":
|
||||
break
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_job_artifacts(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
||||
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
|
||||
assert job_artifacts.job_uuid == "1234"
|
||||
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
|
||||
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
|
||||
assert job_artifacts.checkpoints[0].epoch == 0
|
||||
assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
|
||||
logger.info("Waiting for job to complete...")
|
||||
time.sleep(10) # Increased sleep time to reduce polling frequency
|
||||
|
||||
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
|
||||
logger.info(f"Job artifacts: {artifacts}")
|
||||
|
||||
# TODO: Fix these tests to properly represent the Jobs API in training
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_get_training_jobs(self, post_training_stack):
|
||||
# post_training_impl = post_training_stack
|
||||
# jobs_list = await post_training_impl.get_training_jobs()
|
||||
# assert isinstance(jobs_list, list)
|
||||
# assert jobs_list[0].job_uuid == "1234"
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_get_training_job_status(self, post_training_stack):
|
||||
# post_training_impl = post_training_stack
|
||||
# job_status = await post_training_impl.get_training_job_status("1234")
|
||||
# assert isinstance(job_status, PostTrainingJobStatusResponse)
|
||||
# assert job_status.job_uuid == "1234"
|
||||
# assert job_status.status == JobStatus.completed
|
||||
# assert isinstance(job_status.checkpoints[0], Checkpoint)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_get_training_job_artifacts(self, post_training_stack):
|
||||
# post_training_impl = post_training_stack
|
||||
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
||||
# assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
|
||||
# assert job_artifacts.job_uuid == "1234"
|
||||
# assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
|
||||
# assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
|
||||
# assert job_artifacts.checkpoints[0].epoch == 0
|
||||
# assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue