mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
add unit test
This commit is contained in:
parent
9c80a57667
commit
5a628d32c8
7 changed files with 188 additions and 1 deletions
|
@ -156,4 +156,5 @@ pytest_plugins = [
|
|||
"llama_stack.providers.tests.datasetio.fixtures",
|
||||
"llama_stack.providers.tests.scoring.fixtures",
|
||||
"llama_stack.providers.tests.eval.fixtures",
|
||||
"llama_stack.providers.tests.post_training.fixtures",
|
||||
]
|
||||
|
|
|
@ -10,6 +10,7 @@ import pytest_asyncio
|
|||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
||||
|
||||
|
|
5
llama_stack/providers/tests/post_training/__init__.py
Normal file
5
llama_stack/providers/tests/post_training/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
45
llama_stack/providers/tests/post_training/conftest.py
Normal file
45
llama_stack/providers/tests/post_training/conftest.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import get_provider_fixture_overrides
|
||||
|
||||
from ..datasetio.fixtures import DATASETIO_FIXTURES
|
||||
|
||||
from .fixtures import POST_TRAINING_FIXTURES
|
||||
|
||||
DEFAULT_PROVIDER_COMBINATIONS = [
|
||||
pytest.param(
|
||||
{
|
||||
"post_training": "torchtune",
|
||||
"datasetio": "huggingface",
|
||||
},
|
||||
id="torchtune_post_training_huggingface_datasetio",
|
||||
marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
combined_fixtures = "torchtune_post_training_huggingface_datasetio"
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "post_training_stack" in metafunc.fixturenames:
|
||||
available_fixtures = {
|
||||
"eval": POST_TRAINING_FIXTURES,
|
||||
"datasetio": DATASETIO_FIXTURES,
|
||||
}
|
||||
combinations = (
|
||||
get_provider_fixture_overrides(metafunc.config, available_fixtures)
|
||||
or DEFAULT_PROVIDER_COMBINATIONS
|
||||
)
|
||||
metafunc.parametrize("post_training_stack", combinations, indirect=True)
|
74
llama_stack/providers/tests/post_training/fixtures.py
Normal file
74
llama_stack/providers/tests/post_training/fixtures.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_models.llama3.api.datatypes import URL
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasets import DatasetInput
|
||||
from llama_stack.apis.models import ModelInput
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
from ..conftest import ProviderFixture
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def post_training_torchtune() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="torchtune",
|
||||
provider_type="inline::torchtune",
|
||||
config={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
POST_TRAINING_FIXTURES = ["torchtune"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def post_training_stack(request):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
for key in ["post_training", "datasetio"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.post_training, Api.datasetio],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
|
||||
datasets=[
|
||||
DatasetInput(
|
||||
dataset_id="alpaca",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
|
||||
metadata={
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
"instruction": StringType(),
|
||||
"input": StringType(),
|
||||
"output": StringType(),
|
||||
"text": StringType(),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return test_stack.impls[Api.post_training]
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import pytest
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.post_training import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/post_training/test_post_training.py
|
||||
# -m "torchtune_post_training_huggingface_datasetio"
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
class TestPostTraining:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervised_fine_tune(self, post_training_stack):
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=False,
|
||||
rank=8,
|
||||
alpha=16,
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="alpaca",
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type="adamw",
|
||||
lr=3e-4,
|
||||
lr_min=3e-5,
|
||||
weight_decay=0.1,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
max_steps_per_epoch=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
post_training_impl = post_training_stack
|
||||
response = await post_training_impl.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="Llama3.2-3B-Instruct",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
hyperparam_search_config={},
|
||||
logger_config={},
|
||||
checkpoint_dir="null",
|
||||
)
|
||||
assert isinstance(response, PostTrainingJob)
|
||||
assert response.job_uuid == "1234"
|
|
@ -16,7 +16,7 @@ providers:
|
|||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
post_training:
|
||||
- provider_id: meta-reference-post-training
|
||||
- provider_id: torchtune-post-training
|
||||
provider_type: inline::torchtune
|
||||
config: {}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue