feat: add integration tests for post_training

Set inline::huggingface as the default post_training provider for the ollama distribution and add integration tests for post_training.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
parent 7dcb997f17
commit ff246d890a

10 changed files with 161 additions and 53 deletions
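For context, the flow the new integration test exercises looks roughly like the sketch below: register a dataset, launch a supervised fine-tune through the inline::huggingface provider, then check the job. This is a minimal sketch, assuming a Llama Stack server built from the ollama distribution on localhost:8321 and the llama-stack-client package; the dataset, model, and config values mirror the test added in this commit, while the LoRA rank/alpha/apply_* fields are assumed typical values that are not visible in the diff.

# Minimal sketch of the post_training flow exercised by the new test.
# Assumes a running ollama-distribution server at http://localhost:8321.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register the training dataset by URI (same dataset the test uses).
dataset = client.datasets.register(
    purpose="post-training/messages",
    source={"type": "uri", "uri": "huggingface://datasets/llamastack/simpleqa?split=train"},
)

# Launch a LoRA fine-tune; with inline::huggingface as the default provider,
# the server side runs HF TRL's SFTTrainer.
job = client.post_training.supervised_fine_tune(
    job_uuid="example-job-1",
    model="ibm-granite/granite-3.3-2b-instruct",
    algorithm_config={
        "type": "LoRA",
        "lora_attn_modules": ["q_proj", "v_proj", "output_proj"],
        "apply_lora_to_mlp": True,      # assumed value; not shown in the diff
        "apply_lora_to_output": False,  # assumed value; not shown in the diff
        "rank": 8,                      # assumed value; not shown in the diff
        "alpha": 16,                    # assumed value; not shown in the diff
    },
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 1,
        "gradient_accumulation_steps": 1,
        "data_config": {
            "dataset_id": dataset.identifier,
            "batch_size": 1,
            "shuffle": False,
            "data_format": "instruct",
        },
        "optimizer_config": {
            "optimizer_type": "adamw",
            "lr": 3e-4,
            "lr_min": 3e-5,
            "weight_decay": 0.1,
            "num_warmup_steps": 100,
        },
    },
    hyperparam_search_config={},
    logger_config={},
    checkpoint_dir=None,
)
print(client.post_training.job.status(job_uuid="example-job-1"))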
.github/workflows/integration-tests.yml (vendored, 16 changes)
@@ -58,7 +58,7 @@ jobs:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
           source .venv/bin/activate
-          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
+          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &

       - name: Wait for Llama Stack server to be ready
         if: matrix.client-type == 'http'
@@ -85,6 +85,11 @@ jobs:
             echo "Ollama health check failed"
             exit 1
           fi

+      - name: Check Storage and Memory Available Before Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
       - name: Run Integration Tests
         env:
@@ -100,12 +105,19 @@ jobs:
           --text-model="meta-llama/Llama-3.2-3B-Instruct" \
           --embedding-model=all-MiniLM-L6-v2

+      - name: Check Storage and Memory Available After Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
+
       - name: Write ollama logs to file
         if: ${{ always() }}
         run: |
           sudo journalctl -u ollama.service > ollama.log
+
       - name: Upload all logs to artifacts
-        if: always()
+        if: ${{ always() }}
         uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
         with:
           name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}

@@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::ollama` |
+| post_training | `inline::huggingface` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |

@@ -441,6 +441,7 @@
     "opentelemetry-exporter-otlp-proto-http",
     "opentelemetry-sdk",
     "pandas",
+    "peft",
     "pillow",
     "psycopg2-binary",
     "pymongo",
@@ -451,9 +452,11 @@
     "scikit-learn",
     "scipy",
     "sentencepiece",
+    "torch",
     "tqdm",
     "transformers",
     "tree_sitter",
+    "trl",
     "uvicorn"
   ],
   "open-benchmark": [

@@ -23,6 +23,8 @@ distribution_spec:
   - inline::basic
   - inline::llm-as-judge
   - inline::braintrust
+  post_training:
+  - inline::huggingface
   tool_runtime:
   - remote::brave-search
   - remote::tavily-search

@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -28,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "post_training": ["inline::huggingface"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",
@@ -47,7 +49,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::faiss",
         config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-
+    posttraining_provider = Provider(
+        provider_id="huggingface",
+        provider_type="inline::huggingface",
+        config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="ollama",
@@ -92,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": [inference_provider],
                 "vector_io": [vector_io_provider_faiss],
+                "post_training": [posttraining_provider],
             },
             default_models=[inference_model, embedding_model],
             default_tool_groups=default_tool_groups,
@@ -100,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": [inference_provider],
                 "vector_io": [vector_io_provider_faiss],
+                "post_training": [posttraining_provider],
                 "safety": [
                     Provider(
                         provider_id="llama-guard",

@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
@@ -80,6 +81,13 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  post_training:
+  - provider_id: huggingface
+    provider_type: inline::huggingface
+    config:
+      checkpoint_format: huggingface
+      distributed_backend: null
+      device: cpu
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search

@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
@@ -78,6 +79,13 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  post_training:
+  - provider_id: huggingface
+    provider_type: inline::huggingface
+    config:
+      checkpoint_format: huggingface
+      distributed_backend: null
+      device: cpu
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search

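With either run config, a quick way to confirm the new provider is wired into a running stack is to list providers from a client. A minimal sketch, assuming llama-stack-client's providers.list() (the call behind the `llama-stack-client providers list` CLI); attribute names may vary by client version:

# Sketch: confirm the post_training provider is registered on a running stack.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
post_training = [p for p in client.providers.list() if p.api == "post_training"]
for p in post_training:
    print(p.provider_id, p.provider_type)  # expected: huggingface inline::huggingface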
@@ -45,6 +45,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = [
     "pytest",
+    "pytest-timeout",
     "pytest-asyncio",
     "pytest-cov",
     "pytest-html",

@@ -4,20 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import logging
+import sys
+import time
+import uuid
+
 import pytest

 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )

+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(autouse=True)
+def capture_output(capsys):
+    """Fixture to capture and display output during test execution."""
+    yield
+    captured = capsys.readouterr()
+    if captured.out:
+        print("\nCaptured stdout:", captured.out)
+    if captured.err:
+        print("\nCaptured stderr:", captured.err)
+
+
+# Force flush stdout to see prints immediately
+sys.stdout.reconfigure(line_buffering=True)
+
 # How to run this test:
 #
 # pytest llama_stack/providers/tests/post_training/test_post_training.py
@@ -25,10 +43,31 @@ from llama_stack.apis.post_training import (
 # -v -s --tb=short --disable-warnings


-@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
 class TestPostTraining:
-    @pytest.mark.asyncio
-    async def test_supervised_fine_tune(self, post_training_stack):
+    @pytest.mark.integration
+    @pytest.mark.parametrize(
+        "purpose, source",
+        [
+            (
+                "post-training/messages",
+                {
+                    "type": "uri",
+                    "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+                },
+            ),
+        ],
+    )
+    @pytest.mark.timeout(360)  # 6 minutes timeout
+    def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
+        logger.info("Starting supervised fine-tuning test")
+
+        # register dataset to train
+        dataset = llama_stack_client.datasets.register(
+            purpose=purpose,
+            source=source,
+        )
+        logger.info(f"Registered dataset with ID: {dataset.identifier}")
+
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
             lora_attn_modules=["q_proj", "v_proj", "output_proj"],
@@ -39,62 +78,74 @@ class TestPostTraining:
         )

         data_config = DataConfig(
-            dataset_id="alpaca",
+            dataset_id=dataset.identifier,
             batch_size=1,
             shuffle=False,
             data_format="instruct",
         )

         optimizer_config = OptimizerConfig(
             optimizer_type="adamw",
             lr=3e-4,
             lr_min=3e-5,
             weight_decay=0.1,
             num_warmup_steps=100,
         )

+        # setup training config with minimal settings
         training_config = TrainingConfig(
             n_epochs=1,
             data_config=data_config,
             optimizer_config=optimizer_config,
             max_steps_per_epoch=1,
             gradient_accumulation_steps=1,
         )
-        post_training_impl = post_training_stack
-        response = await post_training_impl.supervised_fine_tune(
-            job_uuid="1234",
-            model="Llama3.2-3B-Instruct",
+
+        job_uuid = f"test-job{uuid.uuid4()}"
+        logger.info(f"Starting training job with UUID: {job_uuid}")
+
+        # train with HF trl SFTTrainer as the default
+        _ = llama_stack_client.post_training.supervised_fine_tune(
+            job_uuid=job_uuid,
+            model="ibm-granite/granite-3.3-2b-instruct",
             algorithm_config=algorithm_config,
             training_config=training_config,
             hyperparam_search_config={},
             logger_config={},
-            checkpoint_dir="null",
+            checkpoint_dir=None,
         )
-        assert isinstance(response, PostTrainingJob)
-        assert response.job_uuid == "1234"

-    @pytest.mark.asyncio
-    async def test_get_training_jobs(self, post_training_stack):
-        post_training_impl = post_training_stack
-        jobs_list = await post_training_impl.get_training_jobs()
-        assert isinstance(jobs_list, list)
-        assert jobs_list[0].job_uuid == "1234"
+        while True:
+            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
+            if not status:
+                logger.error("Job not found")
+                break

-    @pytest.mark.asyncio
-    async def test_get_training_job_status(self, post_training_stack):
-        post_training_impl = post_training_stack
-        job_status = await post_training_impl.get_training_job_status("1234")
-        assert isinstance(job_status, PostTrainingJobStatusResponse)
-        assert job_status.job_uuid == "1234"
-        assert job_status.status == JobStatus.completed
-        assert isinstance(job_status.checkpoints[0], Checkpoint)
+            logger.info(f"Current status: {status}")
+            if status.status == "completed":
+                break

-    @pytest.mark.asyncio
-    async def test_get_training_job_artifacts(self, post_training_stack):
-        post_training_impl = post_training_stack
-        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
-        assert job_artifacts.job_uuid == "1234"
-        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
-        assert job_artifacts.checkpoints[0].epoch == 0
-        assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+            logger.info("Waiting for job to complete...")
+            time.sleep(10)  # Increased sleep time to reduce polling frequency
+
+        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
+        logger.info(f"Job artifacts: {artifacts}")
+
+    # TODO: Fix these tests to properly represent the Jobs API in training
+    # @pytest.mark.asyncio
+    # async def test_get_training_jobs(self, post_training_stack):
+    #     post_training_impl = post_training_stack
+    #     jobs_list = await post_training_impl.get_training_jobs()
+    #     assert isinstance(jobs_list, list)
+    #     assert jobs_list[0].job_uuid == "1234"

+    # @pytest.mark.asyncio
+    # async def test_get_training_job_status(self, post_training_stack):
+    #     post_training_impl = post_training_stack
+    #     job_status = await post_training_impl.get_training_job_status("1234")
+    #     assert isinstance(job_status, PostTrainingJobStatusResponse)
+    #     assert job_status.job_uuid == "1234"
+    #     assert job_status.status == JobStatus.completed
+    #     assert isinstance(job_status.checkpoints[0], Checkpoint)

+    # @pytest.mark.asyncio
+    # async def test_get_training_job_artifacts(self, post_training_stack):
+    #     post_training_impl = post_training_stack
+    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+    #     assert job_artifacts.job_uuid == "1234"
+    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
+    #     assert job_artifacts.checkpoints[0].epoch == 0
+    #     assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path

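The test's open-ended `while True` loop relies on the @pytest.mark.timeout(360) guard to bail out. Outside pytest, a bounded poller is safer; a minimal sketch reusing the same client calls as the test above (treating "failed" as the other terminal state is an assumption, not shown in this diff):

import time

from llama_stack_client import LlamaStackClient


def wait_for_job(client: LlamaStackClient, job_uuid: str, timeout_s: int = 360, poll_s: int = 10):
    """Poll post_training job status until a terminal state or the deadline."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        status = client.post_training.job.status(job_uuid=job_uuid)
        if status is None:
            raise RuntimeError(f"job {job_uuid} not found")
        # "failed" assumed terminal alongside "completed"
        if status.status in ("completed", "failed"):
            return status
        time.sleep(poll_s)
    raise TimeoutError(f"job {job_uuid} did not finish within {timeout_s}s")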
uv.lock (generated, 14 changes)
@@ -1459,6 +1459,7 @@ dev = [
     { name = "pytest-cov" },
     { name = "pytest-html" },
     { name = "pytest-json-report" },
+    { name = "pytest-timeout" },
     { name = "ruamel-yaml" },
     { name = "ruff" },
     { name = "types-requests" },
@@ -1557,6 +1558,7 @@ requires-dist = [
     { name = "pytest-cov", marker = "extra == 'dev'" },
     { name = "pytest-html", marker = "extra == 'dev'" },
     { name = "pytest-json-report", marker = "extra == 'dev'" },
+    { name = "pytest-timeout", marker = "extra == 'dev'" },
     { name = "python-dotenv" },
     { name = "qdrant-client", marker = "extra == 'unit'" },
     { name = "requests" },
@@ -2852,6 +2854,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" },
 ]

+[[package]]
+name = "pytest-timeout"
+version = "2.4.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"