mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
add unit test
This commit is contained in:
parent
9c80a57667
commit
5a628d32c8
7 changed files with 188 additions and 1 deletions
|
@ -156,4 +156,5 @@ pytest_plugins = [
|
||||||
"llama_stack.providers.tests.datasetio.fixtures",
|
"llama_stack.providers.tests.datasetio.fixtures",
|
||||||
"llama_stack.providers.tests.scoring.fixtures",
|
"llama_stack.providers.tests.scoring.fixtures",
|
||||||
"llama_stack.providers.tests.eval.fixtures",
|
"llama_stack.providers.tests.eval.fixtures",
|
||||||
|
"llama_stack.providers.tests.post_training.fixtures",
|
||||||
]
|
]
|
||||||
|
|
|
@ -10,6 +10,7 @@ import pytest_asyncio
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
|
5
llama_stack/providers/tests/post_training/__init__.py
Normal file
5
llama_stack/providers/tests/post_training/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
45
llama_stack/providers/tests/post_training/conftest.py
Normal file
45
llama_stack/providers/tests/post_training/conftest.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ..conftest import get_provider_fixture_overrides
|
||||||
|
|
||||||
|
from ..datasetio.fixtures import DATASETIO_FIXTURES
|
||||||
|
|
||||||
|
from .fixtures import POST_TRAINING_FIXTURES
|
||||||
|
|
||||||
|
# Default provider combination(s) used to parametrize the post-training test
# stack when no explicit provider overrides are given on the pytest command
# line. Each param is a mapping of API name -> fixture suffix, tagged with a
# marker so the combination can be selected via `-m`.
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "post_training": "torchtune",
            "datasetio": "huggingface",
        },
        id="torchtune_post_training_huggingface_datasetio",
        marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
    ),
]
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
    """Register the custom marker used by the default provider combination."""
    marker_name = "torchtune_post_training_huggingface_datasetio"
    description = f"{marker_name}: marks tests as {marker_name} specific"
    config.addinivalue_line("markers", description)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
    """Parametrize the ``post_training_stack`` fixture.

    Provider combinations supplied via command-line overrides take
    precedence; otherwise ``DEFAULT_PROVIDER_COMBINATIONS`` is used.
    """
    if "post_training_stack" in metafunc.fixturenames:
        available_fixtures = {
            # BUG FIX: this key was "eval" (copied from the eval conftest).
            # It must be "post_training" to match the API name used in
            # DEFAULT_PROVIDER_COMBINATIONS and in the post_training_stack
            # fixture; otherwise CLI overrides for post_training never apply.
            "post_training": POST_TRAINING_FIXTURES,
            "datasetio": DATASETIO_FIXTURES,
        }
        combinations = (
            get_provider_fixture_overrides(metafunc.config, available_fixtures)
            or DEFAULT_PROVIDER_COMBINATIONS
        )
        metafunc.parametrize("post_training_stack", combinations, indirect=True)
|
74
llama_stack/providers/tests/post_training/fixtures.py
Normal file
74
llama_stack/providers/tests/post_training/fixtures.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import URL
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasets import DatasetInput
|
||||||
|
from llama_stack.apis.models import ModelInput
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
|
from ..conftest import ProviderFixture
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
def post_training_torchtune() -> ProviderFixture:
    """Session-scoped fixture for the inline torchtune post-training provider."""
    torchtune_provider = Provider(
        provider_id="torchtune",
        provider_type="inline::torchtune",
        config={},
    )
    return ProviderFixture(providers=[torchtune_provider])
|
||||||
|
|
||||||
|
|
||||||
|
# Suffixes of the post-training provider fixtures defined in this module;
# each entry corresponds to a `post_training_<name>` fixture above.
POST_TRAINING_FIXTURES = ["torchtune"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
async def post_training_stack(request):
    """Construct a test stack for the requested provider combination.

    ``request.param`` maps API names ("post_training", "datasetio") to the
    fixture suffix to use. Returns the stack's post_training implementation.
    """
    selected = request.param

    providers = {}
    provider_data = {}
    for api in ("post_training", "datasetio"):
        fixture = request.getfixturevalue(f"{api}_{selected[api]}")
        providers[api] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    # Pre-register the alpaca dataset so fine-tuning tests can reference it.
    alpaca_dataset = DatasetInput(
        dataset_id="alpaca",
        provider_id="huggingface",
        url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
        metadata={
            "path": "tatsu-lab/alpaca",
            "split": "train",
        },
        dataset_schema={
            "instruction": StringType(),
            "input": StringType(),
            "output": StringType(),
            "text": StringType(),
        },
    )

    test_stack = await construct_stack_for_test(
        [Api.post_training, Api.datasetio],
        providers,
        provider_data,
        models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
        datasets=[alpaca_dataset],
    )
    return test_stack.impls[Api.post_training]
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import pytest
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
from llama_stack.apis.post_training import * # noqa: F403
|
||||||
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# pytest llama_stack/providers/tests/post_training/test_post_training.py
|
||||||
|
# -m "torchtune_post_training_huggingface_datasetio"
|
||||||
|
# -v -s --tb=short --disable-warnings
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostTraining:
    """Integration tests for the post_training provider API."""

    @pytest.mark.asyncio
    async def test_supervised_fine_tune(self, post_training_stack):
        """Start a LoRA fine-tuning job and verify the returned job handle."""
        post_training_impl = post_training_stack

        lora_config = LoraFinetuningConfig(
            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
            apply_lora_to_mlp=True,
            apply_lora_to_output=False,
            rank=8,
            alpha=16,
        )

        training_config = TrainingConfig(
            n_epochs=1,
            data_config=DataConfig(
                dataset_id="alpaca",
                batch_size=1,
                shuffle=False,
            ),
            optimizer_config=OptimizerConfig(
                optimizer_type="adamw",
                lr=3e-4,
                lr_min=3e-5,
                weight_decay=0.1,
                num_warmup_steps=100,
            ),
            # keep the run tiny: a single step of a single epoch
            max_steps_per_epoch=1,
            gradient_accumulation_steps=1,
        )

        response = await post_training_impl.supervised_fine_tune(
            job_uuid="1234",
            model="Llama3.2-3B-Instruct",
            algorithm_config=lora_config,
            training_config=training_config,
            hyperparam_search_config={},
            logger_config={},
            checkpoint_dir="null",
        )
        assert isinstance(response, PostTrainingJob)
        assert response.job_uuid == "1234"
|
|
@ -16,7 +16,7 @@ providers:
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
post_training:
|
post_training:
|
||||||
- provider_id: meta-reference-post-training
|
- provider_id: torchtune-post-training
|
||||||
provider_type: inline::torchtune
|
provider_type: inline::torchtune
|
||||||
config: {}
|
config: {}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue