From 49e1a253433acca6eb2e744d49759ce2e69011e8 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Thu, 9 Jan 2025 00:32:18 -0800
Subject: [PATCH] refine

---
 .../apis/post_training/post_training.py       |  5 +-
 .../inline/post_training/common/validator.py  | 48 ++++++++++
 .../post_training/torchtune/common/utils.py   | 94 +------------------
 .../torchtune/datasets/format_adapter.py      | 61 ++++++++++++
 .../post_training/torchtune/datasets/sft.py   | 13 +++
 .../recipes/lora_finetuning_single_device.py  | 16 ++--
 .../utils/common/data_schema_validator.py     |  7 +-
 7 files changed, 134 insertions(+), 110 deletions(-)
 create mode 100644 llama_stack/providers/inline/post_training/common/validator.py
 create mode 100644 llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 216157d15..9bcaa24ac 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -29,10 +29,8 @@ class OptimizerType(Enum):
 
 @json_schema_type
 class DatasetFormat(Enum):
-    alpaca = "alpaca"
     instruct = "instruct"
-    chat_sharegpt = "chat_sharegpt"
-    chat_openai = "chat_openai"
+    chat = "chat"
 
 
 @json_schema_type
@@ -44,7 +42,6 @@ class DataConfig(BaseModel):
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
-    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py
new file mode 100644
index 000000000..2a7f67fd5
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/common/validator.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
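+
+# Shared dataset-schema validation for post-training providers: checks that a
+# registered dataset exposes the columns required by the requested dataset
+# format ("instruct" or "chat") before fine-tuning starts.
+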
+from llama_stack.apis.common.type_system import StringType +from llama_stack.apis.datasets import Datasets +from llama_stack.providers.utils.common.data_schema_validator import ( + ColumnName, + validate_dataset_schema, +) + +EXPECTED_DATASET_SCHEMA = { + "instruct": [ + { + ColumnName.chat_completion_input.value: StringType(), + ColumnName.expected_answer.value: StringType(), + } + ], + "chat": [ + { + ColumnName.dialog.value: StringType(), + } + ], +} + + +async def validate_input_dataset_schema( + datasets_api: Datasets, + dataset_id: str, + dataset_type: str, +) -> None: + dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") + + if dataset_type not in EXPECTED_DATASET_SCHEMA: + raise ValueError(f"Dataset type {dataset_type} is not supported.") + + validate_dataset_schema( + dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type] + ) diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 51d60ac3a..198d2f28e 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,27 +10,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict import torch from llama_models.datatypes import Model from llama_models.sku_list import resolve_model -from llama_stack.apis.common.type_system import ParamType, StringType -from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import DatasetFormat -from llama_stack.providers.utils.common.data_schema_validator import ( - ColumnName, - validate_dataset_schema, -) from pydantic import BaseModel -from torchtune.data._messages import ( - AlpacaToMessages, - InputOutputToMessages, - OpenAIToMessages, - ShareGPTToMessages, -) +from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3._tokenizer import Llama3Tokenizer @@ -45,13 +34,6 @@ class ModelConfig(BaseModel): checkpoint_type: str -class DatasetSchema(BaseModel): - alpaca: List[Dict[str, ParamType]] - instruct: List[Dict[str, ParamType]] - chat_sharegpt: List[Dict[str, ParamType]] - chat_openai: List[Dict[str, ParamType]] - - MODEL_CONFIGS: Dict[str, ModelConfig] = { "Llama3.2-3B-Instruct": ModelConfig( model_definition=lora_llama3_2_3b, @@ -66,49 +48,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = { } DATA_FORMATS: Dict[str, Transform] = { - "alpaca": AlpacaToMessages, "instruct": InputOutputToMessages, - "chat_sharegpt": ShareGPTToMessages, - "chat_openai": OpenAIToMessages, + "chat": ShareGPTToMessages, } -EXPECTED_DATASET_SCHEMA = DatasetSchema( - alpaca=[ - { - ColumnName.instruction.value: StringType(), - ColumnName.input.value: StringType(), - ColumnName.output.value: StringType(), - ColumnName.text.value: StringType(), - }, - { - ColumnName.instruction.value: StringType(), - ColumnName.input.value: StringType(), - ColumnName.output.value: StringType(), - }, - { - ColumnName.instruction.value: StringType(), - ColumnName.output.value: StringType(), - }, - ], - instruct=[ - { - ColumnName.input.value: StringType(), - 
ColumnName.output.value: StringType(), - } - ], - chat_sharegpt=[ - { - ColumnName.conversations.value: StringType(), - } - ], - chat_openai=[ - { - ColumnName.messages.value: StringType(), - } - ], -) - BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildTokenizerCallable = Callable[..., Llama3Tokenizer] @@ -156,35 +100,3 @@ async def get_checkpointer_model_type( async def get_data_transform(data_format: DatasetFormat) -> Transform: return DATA_FORMATS[data_format.value] - - -async def validate_input_dataset_schema( - datasets_api: Datasets, - dataset_id: str, - dataset_type: str, - column_map: Optional[Dict[str, str]] = None, -) -> None: - dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") - - if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type): - raise ValueError(f"Dataset type {dataset_type} is not supported.") - - dataset_schema = {} - - if column_map: - for old_col_name in dataset_def.dataset_schema.keys(): - if old_col_name in column_map.values(): - new_col_name = next( - k for k, v in column_map.items() if v == old_col_name - ) - dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name] - else: - dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name] - else: - dataset_schema = dataset_def.dataset_schema - - validate_dataset_schema( - dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type) - ) diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py new file mode 100644 index 000000000..d36244464 --- /dev/null +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
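+
+# Adapters that convert llama-stack dataset rows into the row shapes expected
+# by torchtune's message transforms: "input"/"output" pairs for the instruct
+# format and ShareGPT-style "conversations" for the chat format.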
+
+from typing import Any, Mapping
+
+from llama_stack.providers.utils.common.data_schema_validator import ColumnName
+
+
+def llama_stack_instruct_to_torchtune_instruct(
+    sample: Mapping[str, Any]
+) -> Mapping[str, Any]:
+    assert (
+        ColumnName.chat_completion_input.value in sample
+        and ColumnName.expected_answer.value in sample
+    ), "Invalid input row"
+    # the chat_completion_input column stores a stringified list of messages;
+    # eval() restores it to a Python list
+    input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
+
+    assert (
+        len(input_messages) == 1
+    ), "llama stack instruct dataset format only supports 1 user message"
+    input_message = input_messages[0]
+
+    assert "content" in input_message, "content not found in input message"
+    output = sample[ColumnName.expected_answer.value]
+
+    return {
+        "input": input_message["content"],
+        "output": output,
+    }
+
+
+def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+    assert ColumnName.dialog.value in sample, "Invalid input row"
+    # only user/assistant roles are supported; they map onto ShareGPT's
+    # "human"/"gpt" speakers
+    role_map = {"user": "human", "assistant": "gpt"}
+    dialog = eval(str(sample[ColumnName.dialog.value]))
+
+    assert len(dialog) > 1, "dialog must have at least 2 messages"
+    roles = []
+    conversations = []
+    for message in dialog:
+        assert (
+            "role" in message and "content" in message
+        ), "role and content must be in message"
+        roles.append(message["role"])
+        conversations.append(
+            {"from": role_map[message["role"]], "value": message["content"]}
+        )
+
+    assert roles[0] == "user", "first message must be from user"
+    assert "assistant" in roles, "at least 1 message should be from assistant"
+
+    return {"conversations": conversations}
diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
index 1f91dc73f..5501044c4 100644
--- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
+++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
@@ -13,6 +13,10 @@ from typing import Any, Dict, List, Mapping
 
 import numpy as np
 
+from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
+    llama_stack_chat_to_torchtune_chat,
+    llama_stack_instruct_to_torchtune_instruct,
+)
 from torch.utils.data import Dataset
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
 
@@ -26,10 +30,12 @@ class SFTDataset(Dataset):
         rows: List[Dict[str, Any]],
         message_transform: Transform,
         model_transform: Transform,
+        dataset_type: str,
     ) -> None:
         self._rows = rows
         self._message_transform = message_transform
         self._model_transform = model_transform
+        self._dataset_type = dataset_type
 
     def __len__(self):
         return len(self._rows)
@@ -39,6 +45,13 @@ class SFTDataset(Dataset):
         return self._prepare_sample(sample)
 
     def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+        if self._dataset_type == "instruct":
+            sample = llama_stack_instruct_to_torchtune_instruct(sample)
+        elif self._dataset_type == "chat":
+            sample = llama_stack_chat_to_torchtune_chat(sample)
+        else:
+            raise ValueError(f"Invalid dataset type: {self._dataset_type}")
+
         transformed_sample = self._message_transform(sample)
         if "messages" in transformed_sample:
             validate_messages(transformed_sample["messages"])
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index a3649d5ae..95e2bc220 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -29,6 +29,9 @@ from llama_stack.apis.post_training import (
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
 
+from llama_stack.providers.inline.post_training.common.validator import (
+    validate_input_dataset_schema,
+)
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@@ -132,7 +135,7 @@ class LoraFinetuningSingleDevice:
         self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
-        self._column_map = training_config.data_config.column_map
+        self._train_on_input = training_config.data_config.train_on_input
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -356,22 +359,17 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.rows
 
-        # Curretly only support alpaca instruct dataset
-        # TODO @SLR722 make the message_transform swappable and support more dataset types
-        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
-        await utils.validate_input_dataset_schema(
+        await validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
             dataset_type=self._data_format.value,
-            column_map=self._column_map,
         )
         data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=data_transform(
-                train_on_input=False, column_map=self._column_map
-            ),
+            message_transform=data_transform(train_on_input=self._train_on_input),
             model_transform=tokenizer,
+            dataset_type=self._data_format.value,
         )
 
         sampler = DistributedSampler(
diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py
index 0322602b7..55f1078a4 100644
--- a/llama_stack/providers/utils/common/data_schema_validator.py
+++ b/llama_stack/providers/utils/common/data_schema_validator.py
@@ -23,12 +23,7 @@ class ColumnName(Enum):
     completion_input = "completion_input"
     generated_answer = "generated_answer"
     context = "context"
-    instruction = "instruction"
-    input = "input"
-    output = "output"
-    text = "text"
-    conversations = "conversations"
-    messages = "messages"
+    dialog = "dialog"
 
 
 VALID_SCHEMAS_FOR_SCORING = [