add unit test

This commit is contained in:
Botao Chen 2024-12-05 20:00:28 -08:00
parent 9c80a57667
commit 5a628d32c8
7 changed files with 188 additions and 1 deletions

View file

@ -156,4 +156,5 @@ pytest_plugins = [
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
]

View file

@ -10,6 +10,7 @@ import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from .fixtures import POST_TRAINING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"post_training": "torchtune",
"datasetio": "huggingface",
},
id="torchtune_post_training_huggingface_datasetio",
marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
),
]
def pytest_configure(config):
combined_fixtures = "torchtune_post_training_huggingface_datasetio"
config.addinivalue_line(
"markers",
f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
)
def pytest_generate_tests(metafunc):
if "post_training_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": POST_TRAINING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("post_training_stack", combinations, indirect=True)

View file

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def post_training_torchtune() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="torchtune",
provider_type="inline::torchtune",
config={},
)
],
)
POST_TRAINING_FIXTURES = ["torchtune"]
@pytest_asyncio.fixture(scope="session")
async def post_training_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["post_training", "datasetio"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.post_training, Api.datasetio],
providers,
provider_data,
models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
datasets=[
DatasetInput(
dataset_id="alpaca",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
metadata={
"path": "tatsu-lab/alpaca",
"split": "train",
},
dataset_schema={
"instruction": StringType(),
"input": StringType(),
"output": StringType(),
"text": StringType(),
},
),
],
)
return test_stack.impls[Api.post_training]

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
# How to run this test:
#
# pytest llama_stack/providers/tests/post_training/test_post_training.py
# -m "torchtune_post_training_huggingface_datasetio"
# -v -s --tb=short --disable-warnings
class TestPostTraining:
@pytest.mark.asyncio
async def test_supervised_fine_tune(self, post_training_stack):
algorithm_config = LoraFinetuningConfig(
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
rank=8,
alpha=16,
)
data_config = DataConfig(
dataset_id="alpaca",
batch_size=1,
shuffle=False,
)
optimizer_config = OptimizerConfig(
optimizer_type="adamw",
lr=3e-4,
lr_min=3e-5,
weight_decay=0.1,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
max_steps_per_epoch=1,
gradient_accumulation_steps=1,
)
post_training_impl = post_training_stack
response = await post_training_impl.supervised_fine_tune(
job_uuid="1234",
model="Llama3.2-3B-Instruct",
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config={},
logger_config={},
checkpoint_dir="null",
)
assert isinstance(response, PostTrainingJob)
assert response.job_uuid == "1234"

View file

@ -16,7 +16,7 @@ providers:
provider_type: inline::meta-reference
config: {}
post_training:
- provider_id: meta-reference-post-training
- provider_id: torchtune-post-training
provider_type: inline::torchtune
config: {}