[post training] define llama stack post training dataset format (#717)

## context
In this PR, we defined 2 llama stack dataset formats (instruct, dialog)

- For instruct dataset format, the column schema will be
[chat_completion_input, expected_answer], which is consistent with the
eval data format. This dataset format is the abstract of single turn QA
style post training data
- For dialog dataset format, the column schema will be [dialog], which
is a list of user messages and assistant messages that interleave
together. During training, the whole list will be the model input and
the loss is calculated on assistant messages only. This dataset format
is the abstract of multi turn chat style post training data

## changes
- defined the 2 llama stack dataset formats
- an adapter to convert llama stack dataset format to torchtune dataset
format
- move dataset format validation to post training level instead of
torchtune level since it's not specific to torchtune
- add localfs as datasetio provider


## test 
instruct format
- use https://huggingface.co/datasets/llamastack/evals as dataset and
the training works as expected
<img width="1443" alt="Screenshot 2025-01-09 at 5 15 14 PM"
src="https://github.com/user-attachments/assets/2c37a936-c67a-4726-90e0-23fa0ba7000f"
/>

- use my generated local dataset and the training works as expected

<img width="1617" alt="Screenshot 2025-01-09 at 5 19 11 PM"
src="https://github.com/user-attachments/assets/0bdccbbf-bac2-472a-a365-15213e49bbfa"
/>


dialog format
- use my generated local dataset and the training works as expected
<img width="1588" alt="Screenshot 2025-01-09 at 5 23 16 PM"
src="https://github.com/user-attachments/assets/893915ba-41a3-4d51-948b-e872060ecede"
/>
This commit is contained in:
Botao Chen 2025-01-14 12:48:49 -08:00 committed by GitHub
parent a174938fbd
commit 25c1d9b037
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 182 additions and 75 deletions

View file

@ -54,6 +54,12 @@ class AgentTurnInputType(BaseModel):
type: Literal["agent_turn_input"] = "agent_turn_input" type: Literal["agent_turn_input"] = "agent_turn_input"
class DialogType(BaseModel):
# expects List[Message] for messages
# this type semantically contains the output label whereas ChatCompletionInputType does not
type: Literal["dialog"] = "dialog"
ParamType = register_schema( ParamType = register_schema(
Annotated[ Annotated[
Union[ Union[

View file

@ -27,11 +27,18 @@ class OptimizerType(Enum):
sgd = "sgd" sgd = "sgd"
@json_schema_type
class DatasetFormat(Enum):
instruct = "instruct"
dialog = "dialog"
@json_schema_type @json_schema_type
class DataConfig(BaseModel): class DataConfig(BaseModel):
dataset_id: str dataset_id: str
batch_size: int batch_size: int
shuffle: bool shuffle: bool
data_format: DatasetFormat
validation_dataset_id: Optional[str] = None validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False packed: Optional[bool] = False
train_on_input: Optional[bool] = False train_on_input: Optional[bool] = False

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
StringType,
)
from llama_stack.apis.datasets import Datasets
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)
EXPECTED_DATASET_SCHEMA = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
ColumnName.expected_answer.value: StringType(),
}
],
"dialog": [
{
ColumnName.dialog.value: DialogType(),
}
],
}
async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")
validate_dataset_schema(
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
)

View file

@ -10,29 +10,22 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List
import torch import torch
from llama_models.datatypes import Model from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from pydantic import BaseModel from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
from llama_stack.apis.common.type_system import ParamType, StringType from llama_stack.apis.post_training import DatasetFormat
from llama_stack.apis.datasets import Datasets
class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
@ -41,10 +34,6 @@ class ModelConfig(BaseModel):
checkpoint_type: str checkpoint_type: str
class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]
MODEL_CONFIGS: Dict[str, ModelConfig] = { MODEL_CONFIGS: Dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig( "Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b, model_definition=lora_llama3_2_3b,
@ -58,26 +47,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
), ),
} }
DATA_FORMATS: Dict[str, Transform] = {
"instruct": InputOutputToMessages,
"dialog": ShareGPTToMessages,
}
EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
ColumnName.text.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
]
)
BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer] BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@ -124,19 +98,5 @@ async def get_checkpointer_model_type(
return model_config.checkpoint_type return model_config.checkpoint_type
async def validate_input_dataset_schema( async def get_data_transform(data_format: DatasetFormat) -> Transform:
datasets_api: Datasets, return DATA_FORMATS[data_format.value]
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any]
) -> Mapping[str, Any]:
assert (
ColumnName.chat_completion_input.value in sample
and ColumnName.expected_answer.value in sample
), "Invalid input row"
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
assert (
len(input_messages) == 1
), "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]
assert "content" in input_message, "content not found in input message"
input = input_message["content"]
output = sample[ColumnName.expected_answer.value]
return {
"input": input,
"output": output,
}
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
dialog = eval(str(sample[ColumnName.dialog.value]))
assert len(dialog) > 1, "dialog must have at least 2 messagse"
roles = []
conversations = []
for message in dialog:
assert (
"role" in message and "content" in message
), "role and content must in message"
roles.append(message["role"])
conversations.append(
{"from": role_map[message["role"]], "value": message["content"]}
)
assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"
return {"conversations": conversations}

View file

@ -19,6 +19,11 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform from torchtune.modules.transforms import Transform
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
llama_stack_chat_to_torchtune_chat,
llama_stack_instruct_to_torchtune_instruct,
)
class SFTDataset(Dataset): class SFTDataset(Dataset):
def __init__( def __init__(
@ -26,10 +31,12 @@ class SFTDataset(Dataset):
rows: List[Dict[str, Any]], rows: List[Dict[str, Any]],
message_transform: Transform, message_transform: Transform,
model_transform: Transform, model_transform: Transform,
dataset_type: str,
) -> None: ) -> None:
self._rows = rows self._rows = rows
self._message_transform = message_transform self._message_transform = message_transform
self._model_transform = model_transform self._model_transform = model_transform
self._dataset_type = dataset_type
def __len__(self): def __len__(self):
return len(self._rows) return len(self._rows)
@ -39,6 +46,12 @@ class SFTDataset(Dataset):
return self._prepare_sample(sample) return self._prepare_sample(sample)
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
if self._dataset_type == "instruct":
sample = llama_stack_instruct_to_torchtune_instruct(sample)
elif self._dataset_type == "dialog":
sample = llama_stack_chat_to_torchtune_chat(sample)
else:
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
transformed_sample = self._message_transform(sample) transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample: if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"]) validate_messages(transformed_sample["messages"])

View file

@ -18,7 +18,7 @@ from torch import nn
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import ( from torchtune.modules.peft import (
@ -47,6 +47,9 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)
from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@ -129,8 +132,10 @@ class LoraFinetuningSingleDevice:
self.seed = training.set_seed(seed=config.torch_seed) self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0 self.epochs_run = 0
self.total_epochs = training_config.n_epochs self.total_epochs = training_config.n_epochs
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size self._batch_size = training_config.data_config.batch_size
self._train_on_input = training_config.data_config.train_on_input
# this is important for debugging purpose # this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch self.max_steps_per_epoch = training_config.max_steps_per_epoch
@ -354,18 +359,17 @@ class LoraFinetuningSingleDevice:
all_rows = await fetch_rows(dataset_id) all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows rows = all_rows.rows
# Curretly only support alpaca instruct dataset await validate_input_dataset_schema(
# TODO @SLR722 make the message_transform swappable and support more dataset types
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
await utils.validate_input_dataset_schema(
datasets_api=self.datasets_api, datasets_api=self.datasets_api,
dataset_id=dataset_id, dataset_id=dataset_id,
dataset_type="alpaca", dataset_type=self._data_format.value,
) )
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset( ds = SFTDataset(
rows, rows,
message_transform=AlpacaToMessages(train_on_input=False), message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer, model_transform=tokenizer,
dataset_type=self._data_format.value,
) )
sampler = DistributedSampler( sampler = DistributedSampler(

View file

@ -23,6 +23,7 @@ class ColumnName(Enum):
completion_input = "completion_input" completion_input = "completion_input"
generated_answer = "generated_answer" generated_answer = "generated_answer"
context = "context" context = "context"
dialog = "dialog"
VALID_SCHEMAS_FOR_SCORING = [ VALID_SCHEMAS_FOR_SCORING = [

View file

@ -13,6 +13,7 @@ distribution_spec:
post_training: post_training:
- inline::torchtune - inline::torchtune
datasetio: datasetio:
- inline::localfs
- remote::huggingface - remote::huggingface
telemetry: telemetry:
- inline::meta-reference - inline::meta-reference
@ -22,4 +23,6 @@ distribution_spec:
- inline::llama-guard - inline::llama-guard
memory: memory:
- inline::faiss - inline::faiss
tool_runtime:
- remote::brave-search
image_type: conda image_type: conda

View file

@ -12,6 +12,7 @@ apis:
- scoring - scoring
- telemetry - telemetry
- post_training - post_training
- tool_runtime
providers: providers:
inference: inference:
- provider_id: meta-reference-inference - provider_id: meta-reference-inference
@ -32,6 +33,9 @@ providers:
- provider_id: huggingface-0 - provider_id: huggingface-0
provider_type: remote::huggingface provider_type: remote::huggingface
config: {} config: {}
- provider_id: localfs
provider_type: inline::localfs
config: {}
telemetry: telemetry:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -60,6 +64,13 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
metadata_store: metadata_store:
namespace: null namespace: null
@ -68,23 +79,6 @@ metadata_store:
models: [] models: []
shields: [] shields: []
memory_banks: [] memory_banks: []
datasets: datasets: []
- dataset_id: alpaca
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
metadata:
path: tatsu-lab/alpaca
name:
split: train
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
text:
type: string
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []