[post training] define llama stack post training dataset format (#717)

## context
In this PR, we defined two llama stack dataset formats (instruct and dialog).

- For the instruct dataset format, the column schema is
[chat_completion_input, expected_answer], which is consistent with the
eval data format. This format is the abstraction for single-turn,
QA-style post training data.
- For the dialog dataset format, the column schema is [dialog], a list of
interleaved user and assistant messages. During training, the whole list
is the model input and the loss is computed on the assistant messages
only. This format is the abstraction for multi-turn, chat-style post
training data. Example rows for both formats are sketched right after
this list.

## changes
- defined the two llama stack dataset formats
- added an adapter to convert the llama stack dataset format to the
torchtune dataset format (a sketch follows this list)
- moved dataset format validation to the post training level instead of
the torchtune level, since it's not specific to torchtune
- added localfs as a datasetio provider
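As a rough sketch of what the adapter does, mirroring the `DATA_FORMATS` table in the diff below: each llama stack format maps to a torchtune message transform. The `column_map` values here are my assumptions about how the llama stack columns line up with the keys each torchtune transform expects:

```python
from torchtune.data import InputOutputToMessages, ShareGPTToMessages
from torchtune.modules.transforms import Transform

# Hypothetical adapter: pick a torchtune message transform per format.
def make_message_transform(dataset_format: str) -> Transform:
    if dataset_format == "instruct":
        # [chat_completion_input, expected_answer] -> torchtune "input"/"output"
        return InputOutputToMessages(
            column_map={"input": "chat_completion_input", "output": "expected_answer"}
        )
    if dataset_format == "dialog":
        # [dialog] -> torchtune "conversations"; train on assistant turns only
        return ShareGPTToMessages(
            train_on_input=False,
            column_map={"conversations": "dialog"},
        )
    raise ValueError(f"unsupported dataset format: {dataset_format}")
```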


## test
instruct format
- used https://huggingface.co/datasets/llamastack/evals as the dataset;
training works as expected
<img width="1443" alt="Screenshot 2025-01-09 at 5 15 14 PM"
src="https://github.com/user-attachments/assets/2c37a936-c67a-4726-90e0-23fa0ba7000f"
/>

- used a locally generated dataset; training works as expected

<img width="1617" alt="Screenshot 2025-01-09 at 5 19 11 PM"
src="https://github.com/user-attachments/assets/0bdccbbf-bac2-472a-a365-15213e49bbfa"
/>


dialog format
- used a locally generated dataset; training works as expected
<img width="1588" alt="Screenshot 2025-01-09 at 5 23 16 PM"
src="https://github.com/user-attachments/assets/893915ba-41a3-4d51-948b-e872060ecede"
/>

```diff
@@ -10,29 +10,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict

 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
+from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

-from llama_stack.apis.common.type_system import ParamType, StringType
-from llama_stack.apis.datasets import Datasets
-
-
-class ColumnName(Enum):
-    instruction = "instruction"
-    input = "input"
-    output = "output"
-    text = "text"
+from llama_stack.apis.post_training import DatasetFormat


 class ModelConfig(BaseModel):
@@ -41,10 +34,6 @@ class ModelConfig(BaseModel):
     checkpoint_type: str


-class DatasetSchema(BaseModel):
-    alpaca: List[Dict[str, ParamType]]
-
-
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
     "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
@@ -58,26 +47,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }

+DATA_FORMATS: Dict[str, Transform] = {
+    "instruct": InputOutputToMessages,
+    "dialog": ShareGPTToMessages,
+}
-EXPECTED_DATASET_SCHEMA = DatasetSchema(
-    alpaca=[
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-            ColumnName.text.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-    ]
-)

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -124,19 +98,5 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type


-async def validate_input_dataset_schema(
-    datasets_api: Datasets,
-    dataset_id: str,
-    dataset_type: str,
-) -> None:
-    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
-    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
-
-    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(f"Dataset type {dataset_type} is not supported.")
-
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(
-            f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
-        )
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
```
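For reference, a minimal sketch of how a training recipe might consume `get_data_transform` together with torchtune's `SFTDataset`. The `source`/`data_files` wiring and the bare instantiation of the transform class are assumptions, not code from this PR (in practice the instruct format would likely also need a `column_map`, as sketched earlier):

```python
from llama_stack.apis.post_training import DatasetFormat
from torchtune.datasets import SFTDataset
from torchtune.models.llama3 import llama3_tokenizer

# Hypothetical wiring; paths and instantiation details are assumptions.
async def build_sft_dataset(
    data_format: DatasetFormat, tokenizer_path: str, data_files: str
) -> SFTDataset:
    transform_cls = await get_data_transform(data_format)  # class from DATA_FORMATS
    return SFTDataset(
        source="json",          # load local JSON rows via Hugging Face datasets
        data_files=data_files,  # forwarded to datasets.load_dataset(...)
        split="train",
        message_transform=transform_cls(),
        model_transform=llama3_tokenizer(path=tokenizer_path),
    )
```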