Botao Chen 2025-01-09 00:32:18 -08:00
parent f39dcdec9d
commit 49e1a25343
7 changed files with 134 additions and 110 deletions

View file: llama_stack/apis/post_training/post_training.py

@@ -29,10 +29,8 @@ class OptimizerType(Enum):
 
 @json_schema_type
 class DatasetFormat(Enum):
-    alpaca = "alpaca"
     instruct = "instruct"
-    chat_sharegpt = "chat_sharegpt"
-    chat_openai = "chat_openai"
+    chat = "chat"
 
 @json_schema_type
@@ -44,7 +42,6 @@ class DataConfig(BaseModel):
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
-    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
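For orientation, a minimal sketch of a client-side DataConfig under the new scheme; the dataset ID and field values are hypothetical, not from the commit:

from llama_stack.apis.post_training import DataConfig, DatasetFormat

# Hypothetical dataset ID. With column_map removed, the dataset's rows must
# already use the llama stack column names that the validator below expects.
data_config = DataConfig(
    dataset_id="my-instruct-dataset",
    batch_size=4,
    shuffle=True,
    data_format=DatasetFormat.instruct,
    train_on_input=False,
)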

View file: llama_stack/providers/inline/post_training/common/validator.py

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import StringType
+from llama_stack.apis.datasets import Datasets
+from llama_stack.providers.utils.common.data_schema_validator import (
+    ColumnName,
+    validate_dataset_schema,
+)
+
+EXPECTED_DATASET_SCHEMA = {
+    "instruct": [
+        {
+            ColumnName.chat_completion_input.value: StringType(),
+            ColumnName.expected_answer.value: StringType(),
+        }
+    ],
+    "chat": [
+        {
+            ColumnName.dialog.value: StringType(),
+        }
+    ],
+}
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if dataset_type not in EXPECTED_DATASET_SCHEMA:
+        raise ValueError(f"Dataset type {dataset_type} is not supported.")
+
+    validate_dataset_schema(
+        dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
+    )
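As an illustration (not part of the commit), a dataset registered for the instruct path needs a schema shaped like this to pass the checks above; a missing schema or an unknown dataset_type raises ValueError:

from llama_stack.apis.common.type_system import StringType

# Mirrors EXPECTED_DATASET_SCHEMA["instruct"]: both columns typed as strings.
conforming_instruct_schema = {
    "chat_completion_input": StringType(),
    "expected_answer": StringType(),
}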

View file: llama_stack/providers/inline/post_training/torchtune/common/utils.py

@@ -10,27 +10,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.common.type_system import ParamType, StringType
-from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import DatasetFormat
-from llama_stack.providers.utils.common.data_schema_validator import (
-    ColumnName,
-    validate_dataset_schema,
-)
 from pydantic import BaseModel
-from torchtune.data._messages import (
-    AlpacaToMessages,
-    InputOutputToMessages,
-    OpenAIToMessages,
-    ShareGPTToMessages,
-)
+from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
@@ -45,13 +34,6 @@ class ModelConfig(BaseModel):
     checkpoint_type: str
 
 
-class DatasetSchema(BaseModel):
-    alpaca: List[Dict[str, ParamType]]
-    instruct: List[Dict[str, ParamType]]
-    chat_sharegpt: List[Dict[str, ParamType]]
-    chat_openai: List[Dict[str, ParamType]]
-
-
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
     "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
@@ -66,49 +48,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
 }
 
 DATA_FORMATS: Dict[str, Transform] = {
-    "alpaca": AlpacaToMessages,
     "instruct": InputOutputToMessages,
-    "chat_sharegpt": ShareGPTToMessages,
-    "chat_openai": OpenAIToMessages,
+    "chat": ShareGPTToMessages,
 }
 
-EXPECTED_DATASET_SCHEMA = DatasetSchema(
-    alpaca=[
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-            ColumnName.text.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-    ],
-    instruct=[
-        {
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-        }
-    ],
-    chat_sharegpt=[
-        {
-            ColumnName.conversations.value: StringType(),
-        }
-    ],
-    chat_openai=[
-        {
-            ColumnName.messages.value: StringType(),
-        }
-    ],
-)
-
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -156,35 +100,3 @@ async def get_checkpointer_model_type(
 
 async def get_data_transform(data_format: DatasetFormat) -> Transform:
     return DATA_FORMATS[data_format.value]
-
-
-async def validate_input_dataset_schema(
-    datasets_api: Datasets,
-    dataset_id: str,
-    dataset_type: str,
-    column_map: Optional[Dict[str, str]] = None,
-) -> None:
-    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
-    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
-
-    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(f"Dataset type {dataset_type} is not supported.")
-
-    dataset_schema = {}
-    if column_map:
-        for old_col_name in dataset_def.dataset_schema.keys():
-            if old_col_name in column_map.values():
-                new_col_name = next(
-                    k for k, v in column_map.items() if v == old_col_name
-                )
-                dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
-            else:
-                dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
-    else:
-        dataset_schema = dataset_def.dataset_schema
-
-    validate_dataset_schema(
-        dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type)
-    )

View file: llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py

@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Mapping
+
+from llama_stack.providers.utils.common.data_schema_validator import ColumnName
+
+
+def llama_stack_instruct_to_torchtune_instruct(
+    sample: Mapping[str, Any]
+) -> Mapping[str, Any]:
+    assert (
+        ColumnName.chat_completion_input.value in sample
+        and ColumnName.expected_answer.value in sample
+    ), "Invalid input row"
+
+    # The column stores the messages as a stringified Python list of dicts.
+    input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
+    assert (
+        len(input_messages) == 1
+    ), "llama stack instruct dataset format only supports 1 user message"
+    input_message = input_messages[0]
+
+    assert "content" in input_message, "content not found in input message"
+    input = input_message["content"]
+    output = sample[ColumnName.expected_answer.value]
+
+    return {
+        "input": input,
+        "output": output,
+    }
+
+
+def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+    assert ColumnName.dialog.value in sample, "Invalid input row"
+    role_map = {"user": "human", "assistant": "gpt"}
+
+    dialog = eval(str(sample[ColumnName.dialog.value]))
+    assert len(dialog) > 1, "dialog must have at least 2 messages"
+
+    roles = []
+    conversations = []
+    for message in dialog:
+        assert (
+            "role" in message and "content" in message
+        ), "role and content must be in message"
+        roles.append(message["role"])
+        conversations.append(
+            {"from": role_map[message["role"]], "value": message["content"]}
+        )
+
+    assert roles[0] == "user", "first message must be from user"
+    assert "assistant" in roles, "at least 1 message should be from assistant"
+
+    return {"conversations": conversations}

View file: llama_stack/providers/inline/post_training/torchtune/datasets/sft.py

@@ -13,6 +13,10 @@
 from typing import Any, Dict, List, Mapping
 
 import numpy as np
+from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
+    llama_stack_chat_to_torchtune_chat,
+    llama_stack_instruct_to_torchtune_instruct,
+)
 from torch.utils.data import Dataset
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
@@ -26,10 +30,12 @@ class SFTDataset(Dataset):
         rows: List[Dict[str, Any]],
         message_transform: Transform,
         model_transform: Transform,
+        dataset_type: str,
     ) -> None:
         self._rows = rows
         self._message_transform = message_transform
         self._model_transform = model_transform
+        self._dataset_type = dataset_type
 
     def __len__(self):
         return len(self._rows)
@@ -39,6 +45,13 @@
         return self._prepare_sample(sample)
 
     def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+        if self._dataset_type == "instruct":
+            sample = llama_stack_instruct_to_torchtune_instruct(sample)
+        elif self._dataset_type == "chat":
+            sample = llama_stack_chat_to_torchtune_chat(sample)
+        else:
+            raise ValueError(f"Invalid dataset type: {self._dataset_type}")
+
         transformed_sample = self._message_transform(sample)
         if "messages" in transformed_sample:
             validate_messages(transformed_sample["messages"])
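A minimal usage sketch for the new argument (rows and tokenizer are placeholders; the transform pairing mirrors DATA_FORMATS in utils.py above):

from torchtune.data._messages import InputOutputToMessages

# rows: list of dicts fetched via the datasets API; tokenizer: a torchtune
# Llama3Tokenizer used as the model transform.
ds = SFTDataset(
    rows,
    message_transform=InputOutputToMessages(train_on_input=False),
    model_transform=tokenizer,
    dataset_type="instruct",  # routes each row through the instruct adapter first
)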

View file: llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

@@ -29,6 +29,9 @@ from llama_stack.apis.post_training import (
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.inline.post_training.common.validator import (
+    validate_input_dataset_schema,
+)
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@@ -132,7 +135,7 @@ class LoraFinetuningSingleDevice:
         self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
-        self._column_map = training_config.data_config.column_map
+        self._train_on_input = training_config.data_config.train_on_input
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -356,22 +359,17 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.rows
 
-        # Curretly only support alpaca instruct dataset
-        # TODO @SLR722 make the message_transform swappable and support more dataset types
-        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
-        await utils.validate_input_dataset_schema(
+        await validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
             dataset_type=self._data_format.value,
-            column_map=self._column_map,
         )
         data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=data_transform(
-                train_on_input=False, column_map=self._column_map
-            ),
+            message_transform=data_transform(train_on_input=self._train_on_input),
             model_transform=tokenizer,
+            dataset_type=self._data_format.value,
         )
 
         sampler = DistributedSampler(
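Putting the recipe's wiring together, a sketch of the equivalent flow for a chat-format dataset, run inside an async context (datasets_api, rows, and tokenizer are placeholders; the dataset ID is hypothetical):

from llama_stack.apis.post_training import DatasetFormat

await validate_input_dataset_schema(
    datasets_api=datasets_api,
    dataset_id="my-chat-dataset",
    dataset_type=DatasetFormat.chat.value,
)
data_transform = await utils.get_data_transform(DatasetFormat.chat)  # ShareGPTToMessages
ds = SFTDataset(
    rows,
    message_transform=data_transform(train_on_input=False),
    model_transform=tokenizer,
    dataset_type=DatasetFormat.chat.value,
)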

View file: llama_stack/providers/utils/common/data_schema_validator.py

@@ -23,12 +23,7 @@ class ColumnName(Enum):
     completion_input = "completion_input"
     generated_answer = "generated_answer"
     context = "context"
-    instruction = "instruction"
-    input = "input"
-    output = "output"
-    text = "text"
-    conversations = "conversations"
-    messages = "messages"
+    dialog = "dialog"
 
 
 VALID_SCHEMAS_FOR_SCORING = [