forked from phoenix-oss/llama-stack-mirror
[post training] define llama stack post training dataset format (#717)
## context In this PR, we defined 2 llama stack dataset formats (instruct, dialog) - For the instruct dataset format, the column schema will be [chat_completion_input, expected_answer], which is consistent with the eval data format. This dataset format is the abstraction of single-turn QA style post-training data - For the dialog dataset format, the column schema will be [dialog], which is a list of user messages and assistant messages that interleave together. During training, the whole list will be the model input and the loss is calculated on assistant messages only. This dataset format is the abstraction of multi-turn chat style post-training data ## changes - defined the 2 llama stack dataset formats - added an adapter to convert the llama stack dataset format to the torchtune dataset format - moved dataset format validation to the post-training level instead of the torchtune level since it's not specific to torchtune - added localfs as a datasetio provider ## test instruct format - used https://huggingface.co/datasets/llamastack/evals as the dataset and the training works as expected <img width="1443" alt="Screenshot 2025-01-09 at 5 15 14 PM" src="https://github.com/user-attachments/assets/2c37a936-c67a-4726-90e0-23fa0ba7000f" /> - used my generated local dataset and the training works as expected <img width="1617" alt="Screenshot 2025-01-09 at 5 19 11 PM" src="https://github.com/user-attachments/assets/0bdccbbf-bac2-472a-a365-15213e49bbfa" /> dialog format - used my generated local dataset and the training works as expected <img width="1588" alt="Screenshot 2025-01-09 at 5 23 16 PM" src="https://github.com/user-attachments/assets/893915ba-41a3-4d51-948b-e872060ecede" />
This commit is contained in:
parent
a174938fbd
commit
25c1d9b037
11 changed files with 182 additions and 75 deletions
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
from typing import Any, Mapping

from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
|
||||
|
||||
def llama_stack_instruct_to_torchtune_instruct(
    sample: Mapping[str, Any]
) -> Mapping[str, Any]:
    """Convert a llama-stack "instruct" dataset row to torchtune's instruct format.

    The llama-stack row stores a serialized single-message chat input under the
    ``chat_completion_input`` column and the target text under
    ``expected_answer``; torchtune's instruct datasets expect plain
    ``input``/``output`` fields.

    Args:
        sample: Row containing ``chat_completion_input`` (a stringified list
            with exactly one message dict that has a ``content`` key) and
            ``expected_answer``.

    Returns:
        Dict with ``input`` (the single user message's content) and
        ``output`` (the expected answer).

    Raises:
        AssertionError: If a required column is missing, the input does not
            contain exactly one message, or the message lacks ``content``.
        ValueError: If the serialized input is not a valid Python literal.
    """
    assert (
        ColumnName.chat_completion_input.value in sample
        and ColumnName.expected_answer.value in sample
    ), "Invalid input row"
    # ast.literal_eval only parses Python literals; eval() would execute
    # arbitrary code embedded in a dataset row.
    input_messages = ast.literal_eval(str(sample[ColumnName.chat_completion_input.value]))

    assert (
        len(input_messages) == 1
    ), "llama stack instruct dataset format only supports 1 user message"
    input_message = input_messages[0]

    assert "content" in input_message, "content not found in input message"
    # Named to avoid shadowing the `input` builtin.
    input_text = input_message["content"]
    output_text = sample[ColumnName.expected_answer.value]

    return {
        "input": input_text,
        "output": output_text,
    }
|
||||
|
||||
|
||||
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
    """Convert a llama-stack "dialog" dataset row to torchtune's chat format.

    The llama-stack row stores a serialized list of interleaved user/assistant
    messages under the ``dialog`` column; torchtune's chat datasets expect a
    ``conversations`` list of ``{"from": ..., "value": ...}`` entries with
    ShareGPT-style role names (``human``/``gpt``).

    Args:
        sample: Row containing ``dialog`` — a stringified list of message
            dicts, each with ``role`` ("user" or "assistant") and ``content``.

    Returns:
        Dict with a single ``conversations`` key holding the converted
        message list.

    Raises:
        AssertionError: If the ``dialog`` column is missing, the dialog has
            fewer than 2 messages, a message lacks ``role``/``content``, the
            first message is not from the user, or no assistant message exists.
        ValueError: If the serialized dialog is not a valid Python literal.
        KeyError: If a message role is neither "user" nor "assistant".
    """
    assert ColumnName.dialog.value in sample, "Invalid input row"
    role_map = {"user": "human", "assistant": "gpt"}
    # ast.literal_eval only parses Python literals; eval() would execute
    # arbitrary code embedded in a dataset row.
    dialog = ast.literal_eval(str(sample[ColumnName.dialog.value]))

    assert len(dialog) > 1, "dialog must have at least 2 messages"
    roles = []
    conversations = []
    for message in dialog:
        assert (
            "role" in message and "content" in message
        ), "role and content must be in message"
        roles.append(message["role"])
        conversations.append(
            {"from": role_map[message["role"]], "value": message["content"]}
        )

    # Loss is computed on assistant turns only, so a dialog is only usable for
    # training if it starts with a user turn and contains an assistant turn.
    assert roles[0] == "user", "first message must be from user"
    assert "assistant" in roles, "at least 1 message should be from assistant"

    return {"conversations": conversations}
|
|
@ -19,6 +19,11 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
|
|||
from torchtune.data._messages import validate_messages
|
||||
from torchtune.modules.transforms import Transform
|
||||
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
|
||||
llama_stack_chat_to_torchtune_chat,
|
||||
llama_stack_instruct_to_torchtune_instruct,
|
||||
)
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
def __init__(
|
||||
|
@ -26,10 +31,12 @@ class SFTDataset(Dataset):
|
|||
rows: List[Dict[str, Any]],
|
||||
message_transform: Transform,
|
||||
model_transform: Transform,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
self._rows = rows
|
||||
self._message_transform = message_transform
|
||||
self._model_transform = model_transform
|
||||
self._dataset_type = dataset_type
|
||||
|
||||
    def __len__(self):
        # Dataset length is simply the number of raw rows loaded.
        return len(self._rows)
|
||||
|
@ -39,6 +46,12 @@ class SFTDataset(Dataset):
|
|||
return self._prepare_sample(sample)
|
||||
|
||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
if self._dataset_type == "instruct":
|
||||
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
||||
elif self._dataset_type == "dialog":
|
||||
sample = llama_stack_chat_to_torchtune_chat(sample)
|
||||
else:
|
||||
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
|
||||
transformed_sample = self._message_transform(sample)
|
||||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue