mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-30 11:24:20 +00:00
Use correct shapes in unit tests; remove use of unsupported params
This commit is contained in:
parent
26c10b5ab5
commit
bb142435db
3 changed files with 68 additions and 51 deletions
|
|
@ -10,13 +10,17 @@ import warnings
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
|
||||
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
||||
TrainingConfig,
|
||||
TrainingConfigDataConfig,
|
||||
TrainingConfigOptimizerConfig,
|
||||
)
|
||||
|
||||
from llama_stack.apis.post_training.post_training import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
OptimizerType,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
ListNvidiaPostTrainingJobs,
|
||||
NvidiaPostTrainingAdapter,
|
||||
|
|
@ -105,7 +109,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
|
|
@ -116,8 +120,6 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
adapter_dim=16,
|
||||
adapter_dropout=0.1,
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
|
|
@ -125,10 +127,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = TrainingConfigOptimizerConfig(
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
|
|
@ -145,7 +152,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
|
@ -169,16 +176,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_supervised_fine_tune_with_qat(self):
|
||||
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
|
||||
optimizer_config = TrainingConfigOptimizerConfig(
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
|
|
@ -193,7 +206,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue