From ff246d890a8e78f2945410fdddedf50ec0714893 Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Tue, 13 May 2025 17:21:30 -0400
Subject: [PATCH] feat: add integration tests for post_training

set inline::huggingface as the default post_training provider for the
ollama distribution and add integration tests for post_training

Signed-off-by: Charlie Doern
---
 .github/workflows/integration-tests.yml        |  16 +-
 .../self_hosted_distro/ollama.md               |   1 +
 llama_stack/templates/dependencies.json        |   3 +
 llama_stack/templates/ollama/build.yaml        |   2 +
 llama_stack/templates/ollama/ollama.py         |  10 +-
 .../templates/ollama/run-with-safety.yaml      |   8 +
 llama_stack/templates/ollama/run.yaml          |   8 +
 pyproject.toml                                 |   1 +
 .../post_training/test_post_training.py        | 151 ++++++++++++------
 uv.lock                                        |  14 ++
 10 files changed, 161 insertions(+), 53 deletions(-)

diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index d755ff0ae..c083da7d9 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -58,7 +58,7 @@ jobs:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
           source .venv/bin/activate
-          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
+          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &

       - name: Wait for Llama Stack server to be ready
         if: matrix.client-type == 'http'
@@ -85,6 +85,11 @@
             echo "Ollama health check failed"
             exit 1
           fi
+      - name: Check Storage and Memory Available Before Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h

       - name: Run Integration Tests
         env:
@@ -100,12 +105,19 @@
             --text-model="meta-llama/Llama-3.2-3B-Instruct" \
             --embedding-model=all-MiniLM-L6-v2

+      - name: Check Storage and Memory Available After Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
+
       - name: Write ollama logs to file
+        if: ${{ always() }}
         run: |
           sudo journalctl -u ollama.service > ollama.log

       - name: Upload all logs to artifacts
-        if: always()
+        if: ${{ always() }}
         uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
         with:
           name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md
index 5d8935fe2..9acbebe6f 100644
--- a/docs/source/distributions/self_hosted_distro/ollama.md
+++ b/docs/source/distributions/self_hosted_distro/ollama.md
@@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::ollama` |
+| post_training | `inline::huggingface` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json
index d1a17e48e..fb4ab9fda 100644
--- a/llama_stack/templates/dependencies.json
+++ b/llama_stack/templates/dependencies.json
@@ -441,6 +441,7 @@
     "opentelemetry-exporter-otlp-proto-http",
     "opentelemetry-sdk",
     "pandas",
+    "peft",
     "pillow",
     "psycopg2-binary",
     "pymongo",
@@ -451,9 +452,11 @@
     "scikit-learn",
     "scipy",
     "sentencepiece",
+    "torch",
     "tqdm",
     "transformers",
     "tree_sitter",
+    "trl",
     "uvicorn"
   ],
   "open-benchmark": [
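Note: the peft, torch, and trl entries added above are the packages the new inline::huggingface post-training provider pulls in. For orientation only, the snippet below is a minimal sketch of the kind of LoRA SFT loop these libraries support; the dataset, output directory, and hyperparameters are illustrative assumptions, not the provider's actual implementation.

# Illustrative sketch only -- not the provider's code. The dataset, output_dir,
# and hyperparameters below are placeholder assumptions.
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Small slice of a public chat dataset, purely as a placeholder.
train_ds = load_dataset("trl-lib/Capybara", split="train[:100]")

trainer = SFTTrainer(
    model="ibm-granite/granite-3.3-2b-instruct",  # same model the new integration test trains
    train_dataset=train_ds,
    args=SFTConfig(output_dir="./sft-out", max_steps=1, per_device_train_batch_size=1),
    peft_config=LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]),
)
trainer.train()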
diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml
index 88e61bf8a..7d5363575 100644
--- a/llama_stack/templates/ollama/build.yaml
+++ b/llama_stack/templates/ollama/build.yaml
@@ -23,6 +23,8 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    post_training:
+    - inline::huggingface
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search
diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py
index d72d299ec..0b4f05128 100644
--- a/llama_stack/templates/ollama/ollama.py
+++ b/llama_stack/templates/ollama/ollama.py
@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -28,6 +29,7 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "post_training": ["inline::huggingface"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",
@@ -47,7 +49,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::faiss",
         config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-
+    posttraining_provider = Provider(
+        provider_id="huggingface",
+        provider_type="inline::huggingface",
+        config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="ollama",
@@ -92,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": [inference_provider],
                 "vector_io": [vector_io_provider_faiss],
+                "post_training": [posttraining_provider],
             },
             default_models=[inference_model, embedding_model],
             default_tool_groups=default_tool_groups,
@@ -100,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": [inference_provider],
                 "vector_io": [vector_io_provider_faiss],
+                "post_training": [posttraining_provider],
                 "safety": [
                     Provider(
                         provider_id="llama-guard",
diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml
index 651d58117..74d0822ca 100644
--- a/llama_stack/templates/ollama/run-with-safety.yaml
+++ b/llama_stack/templates/ollama/run-with-safety.yaml
@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
@@ -80,6 +81,13 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  post_training:
+  - provider_id: huggingface
+    provider_type: inline::huggingface
+    config:
+      checkpoint_format: huggingface
+      distributed_backend: null
+      device: cpu
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
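Note: both generated run configs (run-with-safety.yaml above and run.yaml below) gain the same post_training block, with device: cpu as a CI-friendly default. As a rough sketch only (assuming HuggingFacePostTrainingConfig accepts the fields shown in the generated YAML as keyword arguments, which is not verified here), a custom template could point the provider at a GPU instead:

# Assumption: the config class exposes checkpoint_format, distributed_backend,
# and device as constructor keywords, mirroring the generated YAML; treat this
# as a sketch, not a verified API.
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig

gpu_post_training_config = HuggingFacePostTrainingConfig(
    checkpoint_format="huggingface",
    distributed_backend=None,
    device="cuda",  # the generated default is "cpu"
)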
diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml
index 1372486fe..71229be97 100644
--- a/llama_stack/templates/ollama/run.yaml
+++ b/llama_stack/templates/ollama/run.yaml
@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
@@ -78,6 +79,13 @@ providers:
     provider_type: inline::braintrust
     config:
      openai_api_key: ${env.OPENAI_API_KEY:}
+  post_training:
+  - provider_id: huggingface
+    provider_type: inline::huggingface
+    config:
+      checkpoint_format: huggingface
+      distributed_backend: null
+      device: cpu
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
diff --git a/pyproject.toml b/pyproject.toml
index ba7c2300a..1fe64f350 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = [
     "pytest",
+    "pytest-timeout",
     "pytest-asyncio",
     "pytest-cov",
     "pytest-html",
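Note: the new pytest-timeout dev dependency is what backs the per-test time limit used by the integration test below. The marker form it provides looks like this:

import pytest


@pytest.mark.timeout(360)  # fail the test if it is still running after 6 minutes
def test_long_running_job():
    ...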
diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py
index 648ace9d6..bb4639d17 100644
--- a/tests/integration/post_training/test_post_training.py
+++ b/tests/integration/post_training/test_post_training.py
@@ -4,20 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import logging
+import sys
+import time
+import uuid
+
 import pytest

-from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
-    Checkpoint,
     DataConfig,
     LoraFinetuningConfig,
-    OptimizerConfig,
-    PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )

+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(autouse=True)
+def capture_output(capsys):
+    """Fixture to capture and display output during test execution."""
+    yield
+    captured = capsys.readouterr()
+    if captured.out:
+        print("\nCaptured stdout:", captured.out)
+    if captured.err:
+        print("\nCaptured stderr:", captured.err)
+
+
+# Force flush stdout to see prints immediately
+sys.stdout.reconfigure(line_buffering=True)
+
 # How to run this test:
 #
 # pytest llama_stack/providers/tests/post_training/test_post_training.py
@@ -25,10 +43,31 @@ from llama_stack.apis.post_training import (
 #  -v -s --tb=short --disable-warnings


-@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
 class TestPostTraining:
-    @pytest.mark.asyncio
-    async def test_supervised_fine_tune(self, post_training_stack):
+    @pytest.mark.integration
+    @pytest.mark.parametrize(
+        "purpose, source",
+        [
+            (
+                "post-training/messages",
+                {
+                    "type": "uri",
+                    "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+                },
+            ),
+        ],
+    )
+    @pytest.mark.timeout(360)  # 6 minutes timeout
+    def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
+        logger.info("Starting supervised fine-tuning test")
+
+        # register dataset to train
+        dataset = llama_stack_client.datasets.register(
+            purpose=purpose,
+            source=source,
+        )
+        logger.info(f"Registered dataset with ID: {dataset.identifier}")
+
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
             lora_attn_modules=["q_proj", "v_proj", "output_proj"],
@@ -39,62 +78,74 @@ class TestPostTraining:
         )

         data_config = DataConfig(
-            dataset_id="alpaca",
+            dataset_id=dataset.identifier,
             batch_size=1,
             shuffle=False,
+            data_format="instruct",
         )

-        optimizer_config = OptimizerConfig(
-            optimizer_type="adamw",
-            lr=3e-4,
-            lr_min=3e-5,
-            weight_decay=0.1,
-            num_warmup_steps=100,
-        )
-
+        # setup training config with minimal settings
         training_config = TrainingConfig(
             n_epochs=1,
             data_config=data_config,
-            optimizer_config=optimizer_config,
             max_steps_per_epoch=1,
             gradient_accumulation_steps=1,
         )
-        post_training_impl = post_training_stack
-        response = await post_training_impl.supervised_fine_tune(
-            job_uuid="1234",
-            model="Llama3.2-3B-Instruct",
+
+        job_uuid = f"test-job{uuid.uuid4()}"
+        logger.info(f"Starting training job with UUID: {job_uuid}")
+
+        # train with HF trl SFTTrainer as the default
+        _ = llama_stack_client.post_training.supervised_fine_tune(
+            job_uuid=job_uuid,
+            model="ibm-granite/granite-3.3-2b-instruct",
             algorithm_config=algorithm_config,
             training_config=training_config,
             hyperparam_search_config={},
             logger_config={},
-            checkpoint_dir="null",
+            checkpoint_dir=None,
         )
-        assert isinstance(response, PostTrainingJob)
-        assert response.job_uuid == "1234"

-    @pytest.mark.asyncio
-    async def test_get_training_jobs(self, post_training_stack):
-        post_training_impl = post_training_stack
-        jobs_list = await post_training_impl.get_training_jobs()
-        assert isinstance(jobs_list, list)
-        assert jobs_list[0].job_uuid == "1234"
+        while True:
+            status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
+            if not status:
+                logger.error("Job not found")
+                break

-    @pytest.mark.asyncio
-    async def test_get_training_job_status(self, post_training_stack):
-        post_training_impl = post_training_stack
-        job_status = await post_training_impl.get_training_job_status("1234")
-        assert isinstance(job_status, PostTrainingJobStatusResponse)
-        assert job_status.job_uuid == "1234"
-        assert job_status.status == JobStatus.completed
-        assert isinstance(job_status.checkpoints[0], Checkpoint)
+            logger.info(f"Current status: {status}")
+            if status.status == "completed":
+                break

-    @pytest.mark.asyncio
-    async def test_get_training_job_artifacts(self, post_training_stack):
-        post_training_impl = post_training_stack
-        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
-        assert job_artifacts.job_uuid == "1234"
-        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
-        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
-        assert job_artifacts.checkpoints[0].epoch == 0
-        assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
+            logger.info("Waiting for job to complete...")
+            time.sleep(10)  # Increased sleep time to reduce polling frequency
+
+        artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
+        logger.info(f"Job artifacts: {artifacts}")
+
+    # TODO: Fix these tests to properly represent the Jobs API in training
+    # @pytest.mark.asyncio
+    # async def test_get_training_jobs(self, post_training_stack):
+    #     post_training_impl = post_training_stack
+    #     jobs_list = await post_training_impl.get_training_jobs()
+    #     assert isinstance(jobs_list, list)
+    #     assert jobs_list[0].job_uuid == "1234"
+
+    # @pytest.mark.asyncio
+    # async def test_get_training_job_status(self, post_training_stack):
+    #     post_training_impl = post_training_stack
+    #     job_status = await post_training_impl.get_training_job_status("1234")
+    #     assert isinstance(job_status, PostTrainingJobStatusResponse)
+    #     assert job_status.job_uuid == "1234"
+    #     assert job_status.status == JobStatus.completed
+    #     assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    # @pytest.mark.asyncio
+    # async def test_get_training_job_artifacts(self, post_training_stack):
+    #     post_training_impl = post_training_stack
+    #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+    #     assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+    #     assert job_artifacts.job_uuid == "1234"
+    #     assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+    #     assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
+    #     assert job_artifacts.checkpoints[0].epoch == 0
+    #     assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
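Note: for readers who want to drive the same flow outside pytest, the sketch below distills the test above into a standalone script. It assumes a Llama Stack server built from this ollama template is already running (the default port 8321 is assumed), and it passes plain dicts where the test uses the typed config classes; the LoRA fields beyond lora_attn_modules are assumed to match what the full test sets outside the diff context.

# Sketch distilled from the test above; the endpoint, port, and dict-shaped
# configs are assumptions rather than verified API contracts.
import time
import uuid

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register the same HuggingFace-hosted dataset the test uses.
dataset = client.datasets.register(
    purpose="post-training/messages",
    source={"type": "uri", "uri": "huggingface://datasets/llamastack/simpleqa?split=train"},
)

job_uuid = f"sft-{uuid.uuid4()}"
client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    model="ibm-granite/granite-3.3-2b-instruct",
    algorithm_config={
        "type": "LoRA",
        "lora_attn_modules": ["q_proj", "v_proj", "output_proj"],
        # Remaining LoRA fields are assumed; the test sets them outside the diff context.
        "rank": 8,
        "alpha": 16,
        "apply_lora_to_mlp": True,
        "apply_lora_to_output": False,
    },
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 1,
        "gradient_accumulation_steps": 1,
        "data_config": {
            "dataset_id": dataset.identifier,
            "batch_size": 1,
            "shuffle": False,
            "data_format": "instruct",
        },
    },
    hyperparam_search_config={},
    logger_config={},
    checkpoint_dir=None,
)

# Poll until the job reports completion, mirroring the loop in the test.
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if status and status.status == "completed":
        break
    time.sleep(10)

print(client.post_training.job.artifacts(job_uuid=job_uuid))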
job_artifacts.job_uuid == "1234" + # assert isinstance(job_artifacts.checkpoints[0], Checkpoint) + # assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab" + # assert job_artifacts.checkpoints[0].epoch == 0 + # assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path diff --git a/uv.lock b/uv.lock index dbf0c891f..6bd3f84d5 100644 --- a/uv.lock +++ b/uv.lock @@ -1459,6 +1459,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-html" }, { name = "pytest-json-report" }, + { name = "pytest-timeout" }, { name = "ruamel-yaml" }, { name = "ruff" }, { name = "types-requests" }, @@ -1557,6 +1558,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "pytest-html", marker = "extra == 'dev'" }, { name = "pytest-json-report", marker = "extra == 'dev'" }, + { name = "pytest-timeout", marker = "extra == 'dev'" }, { name = "python-dotenv" }, { name = "qdrant-client", marker = "extra == 'unit'" }, { name = "requests" }, @@ -2852,6 +2854,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"