diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 8b73500d0..4d7831ae3 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -156,4 +156,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
+    "llama_stack.providers.tests.post_training.fixtures",
 ]
diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py
index f0c8cbbe1..d288198ca 100644
--- a/llama_stack/providers/tests/datasetio/fixtures.py
+++ b/llama_stack/providers/tests/datasetio/fixtures.py
@@ -10,6 +10,7 @@ import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.tests.resolver import construct_stack_for_test
+
 from ..conftest import ProviderFixture, remote_stack_fixture
 
 
diff --git a/llama_stack/providers/tests/post_training/__init__.py b/llama_stack/providers/tests/post_training/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/tests/post_training/conftest.py b/llama_stack/providers/tests/post_training/conftest.py
new file mode 100644
index 000000000..14d349106
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/conftest.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+
+from ..datasetio.fixtures import DATASETIO_FIXTURES
+
+from .fixtures import POST_TRAINING_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "post_training": "torchtune",
+            "datasetio": "huggingface",
+        },
+        id="torchtune_post_training_huggingface_datasetio",
+        marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
+    ),
+]
+
+
+def pytest_configure(config):
+    combined_fixtures = "torchtune_post_training_huggingface_datasetio"
+    config.addinivalue_line(
+        "markers",
+        f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "post_training_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "post_training": POST_TRAINING_FIXTURES,
+            "datasetio": DATASETIO_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("post_training_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py
new file mode 100644
index 000000000..3ca48d847
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/fixtures.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasets import DatasetInput
+from llama_stack.apis.models import ModelInput
+
+from llama_stack.distribution.datatypes import Api, Provider
+
+from llama_stack.providers.tests.resolver import construct_stack_for_test
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def post_training_torchtune() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="torchtune",
+                provider_type="inline::torchtune",
+                config={},
+            )
+        ],
+    )
+
+
+POST_TRAINING_FIXTURES = ["torchtune"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def post_training_stack(request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["post_training", "datasetio"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+
+    test_stack = await construct_stack_for_test(
+        [Api.post_training, Api.datasetio],
+        providers,
+        provider_data,
+        models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
+        datasets=[
+            DatasetInput(
+                dataset_id="alpaca",
+                provider_id="huggingface",
+                url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
+                metadata={
+                    "path": "tatsu-lab/alpaca",
+                    "split": "train",
+                },
+                dataset_schema={
+                    "instruction": StringType(),
+                    "input": StringType(),
+                    "output": StringType(),
+                    "text": StringType(),
+                },
+            ),
+        ],
+    )
+
+    return test_stack.impls[Api.post_training]
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py
new file mode 100644
index 000000000..a4e2d55c9
--- /dev/null
+++ b/llama_stack/providers/tests/post_training/test_post_training.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.post_training import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+# How to run this test:
+#
+# pytest llama_stack/providers/tests/post_training/test_post_training.py
+#   -m "torchtune_post_training_huggingface_datasetio"
+#   -v -s --tb=short --disable-warnings
+
+
+class TestPostTraining:
+    @pytest.mark.asyncio
+    async def test_supervised_fine_tune(self, post_training_stack):
+        algorithm_config = LoraFinetuningConfig(
+            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
+            apply_lora_to_mlp=True,
+            apply_lora_to_output=False,
+            rank=8,
+            alpha=16,
+        )
+
+        data_config = DataConfig(
+            dataset_id="alpaca",
+            batch_size=1,
+            shuffle=False,
+        )
+
+        optimizer_config = OptimizerConfig(
+            optimizer_type="adamw",
+            lr=3e-4,
+            lr_min=3e-5,
+            weight_decay=0.1,
+            num_warmup_steps=100,
+        )
+
+        training_config = TrainingConfig(
+            n_epochs=1,
+            data_config=data_config,
+            optimizer_config=optimizer_config,
+            max_steps_per_epoch=1,
+            gradient_accumulation_steps=1,
+        )
+        post_training_impl = post_training_stack
+        response = await post_training_impl.supervised_fine_tune(
+            job_uuid="1234",
+            model="Llama3.2-3B-Instruct",
+            algorithm_config=algorithm_config,
+            training_config=training_config,
+            hyperparam_search_config={},
+            logger_config={},
+            checkpoint_dir="null",
+        )
+        assert isinstance(response, PostTrainingJob)
+        assert response.job_uuid == "1234"
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index 65ec858ce..4bdde7aa6 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -16,7 +16,7 @@ providers:
     provider_type: inline::meta-reference
     config: {}
   post_training:
-  - provider_id: meta-reference-post-training
+  - provider_id: torchtune-post-training
    provider_type: inline::torchtune
    config: {}