From 25c1d9b03766f2485efc19ca7a8bedc877808044 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Tue, 14 Jan 2025 12:48:49 -0800 Subject: [PATCH] [post training] define llama stack post training dataset format (#717) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## context In this PR, we defined 2 llama stack dataset formats (instruct, dialog) - For instruct dataset format, the column schema will be [chat_completion_input, expected_answer], which is consistent with the eval data format. This dataset format is the abstract of single turn QA style post training data - For dialog dataset format, the column schema will be [dialog], which is a list of user messages and assistant messages that interleave together. During training, the whole list will be the model input and the loss is calculated on assistant messages only. This dataset format is the abstract of multi turn chat style post training data ## changes - defined the 2 llama stack dataset formats - an adapter to convert llama stack dataset format to torchtune dataset format - move dataset format validation to post training level instead of torchtune level since it's not specific to torchtune - add localfs as datasetio provider ## test instruct format - use https://huggingface.co/datasets/llamastack/evals as dataset and the training works as expected Screenshot 2025-01-09 at 5 15 14 PM - use my generated local dataset and the training works as expected Screenshot 2025-01-09 at 5 19 11 PM dialog format - use my generated local dataset and the training works as expected Screenshot 2025-01-09 at 5 23 16 PM --- llama_stack/apis/common/type_system.py | 6 ++ .../apis/post_training/post_training.py | 7 +++ .../inline/post_training/common/__init__.py | 5 ++ .../inline/post_training/common/validator.py | 52 ++++++++++++++++ .../post_training/torchtune/common/utils.py | 60 +++--------------- .../torchtune/datasets/format_adapter.py | 62 +++++++++++++++++++ .../post_training/torchtune/datasets/sft.py | 13 ++++ .../recipes/lora_finetuning_single_device.py | 18 +++--- .../utils/common/data_schema_validator.py | 1 + .../experimental-post-training/build.yaml | 3 + .../experimental-post-training/run.yaml | 30 ++++----- 11 files changed, 182 insertions(+), 75 deletions(-) create mode 100644 llama_stack/providers/inline/post_training/common/__init__.py create mode 100644 llama_stack/providers/inline/post_training/common/validator.py create mode 100644 llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index a653efef9..e76cfde13 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -54,6 +54,12 @@ class AgentTurnInputType(BaseModel): type: Literal["agent_turn_input"] = "agent_turn_input" +class DialogType(BaseModel): + # expects List[Message] for messages + # this type semantically contains the output label whereas ChatCompletionInputType does not + type: Literal["dialog"] = "dialog" + + ParamType = register_schema( Annotated[ Union[ diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 8e1edbe87..8841dc1d0 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -27,11 +27,18 @@ class OptimizerType(Enum): sgd = "sgd" +@json_schema_type +class DatasetFormat(Enum): + instruct = "instruct" + dialog = "dialog" + + @json_schema_type class DataConfig(BaseModel): dataset_id: str batch_size: int shuffle: bool + data_format: DatasetFormat validation_dataset_id: Optional[str] = None packed: Optional[bool] = False train_on_input: Optional[bool] = False diff --git a/llama_stack/providers/inline/post_training/common/__init__.py b/llama_stack/providers/inline/post_training/common/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/post_training/common/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py new file mode 100644 index 000000000..836e20c85 --- /dev/null +++ b/llama_stack/providers/inline/post_training/common/validator.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, IAny, nc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.common.type_system import ( + ChatCompletionInputType, + DialogType, + StringType, +) +from llama_stack.apis.datasets import Datasets +from llama_stack.providers.utils.common.data_schema_validator import ( + ColumnName, + validate_dataset_schema, +) + +EXPECTED_DATASET_SCHEMA = { + "instruct": [ + { + ColumnName.chat_completion_input.value: ChatCompletionInputType(), + ColumnName.expected_answer.value: StringType(), + } + ], + "dialog": [ + { + ColumnName.dialog.value: DialogType(), + } + ], +} + + +async def validate_input_dataset_schema( + datasets_api: Datasets, + dataset_id: str, + dataset_type: str, +) -> None: + dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") + + if dataset_type not in EXPECTED_DATASET_SCHEMA: + raise ValueError(f"Dataset type {dataset_type} is not supported.") + + validate_dataset_schema( + dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type] + ) diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index b4cd43770..88011ead4 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,29 +10,22 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict import torch from llama_models.datatypes import Model from llama_models.sku_list import resolve_model from pydantic import BaseModel +from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.models.llama3_1 import lora_llama3_1_8b from torchtune.models.llama3_2 import lora_llama3_2_3b +from torchtune.modules.transforms import Transform -from llama_stack.apis.common.type_system import ParamType, StringType -from llama_stack.apis.datasets import Datasets - - -class ColumnName(Enum): - instruction = "instruction" - input = "input" - output = "output" - text = "text" +from llama_stack.apis.post_training import DatasetFormat class ModelConfig(BaseModel): @@ -41,10 +34,6 @@ class ModelConfig(BaseModel): checkpoint_type: str -class DatasetSchema(BaseModel): - alpaca: List[Dict[str, ParamType]] - - MODEL_CONFIGS: Dict[str, ModelConfig] = { "Llama3.2-3B-Instruct": ModelConfig( model_definition=lora_llama3_2_3b, @@ -58,26 +47,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = { ), } +DATA_FORMATS: Dict[str, Transform] = { + "instruct": InputOutputToMessages, + "dialog": ShareGPTToMessages, +} -EXPECTED_DATASET_SCHEMA = DatasetSchema( - alpaca=[ - { - ColumnName.instruction.value: StringType(), - ColumnName.input.value: StringType(), - ColumnName.output.value: StringType(), - ColumnName.text.value: StringType(), - }, - { - ColumnName.instruction.value: StringType(), - ColumnName.input.value: StringType(), - ColumnName.output.value: StringType(), - }, - { - ColumnName.instruction.value: StringType(), - ColumnName.output.value: StringType(), - }, - ] -) BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildTokenizerCallable = Callable[..., Llama3Tokenizer] @@ -124,19 +98,5 @@ async def get_checkpointer_model_type( return model_config.checkpoint_type -async def validate_input_dataset_schema( - datasets_api: Datasets, - dataset_id: str, - dataset_type: str, -) -> None: - dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") - - if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type): - raise ValueError(f"Dataset type {dataset_type} is not supported.") - - if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type): - raise ValueError( - f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}" - ) +async def get_data_transform(data_format: DatasetFormat) -> Transform: + return DATA_FORMATS[data_format.value] diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py new file mode 100644 index 000000000..b4dfbb3c1 --- /dev/null +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Mapping + +from llama_stack.providers.utils.common.data_schema_validator import ColumnName + + +def llama_stack_instruct_to_torchtune_instruct( + sample: Mapping[str, Any] +) -> Mapping[str, Any]: + assert ( + ColumnName.chat_completion_input.value in sample + and ColumnName.expected_answer.value in sample + ), "Invalid input row" + input_messages = eval(str(sample[ColumnName.chat_completion_input.value])) + + assert ( + len(input_messages) == 1 + ), "llama stack intruct dataset format only supports 1 user message" + input_message = input_messages[0] + + assert "content" in input_message, "content not found in input message" + input = input_message["content"] + output = sample[ColumnName.expected_answer.value] + + return { + "input": input, + "output": output, + } + + +def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]: + assert ColumnName.dialog.value in sample, "Invalid input row" + role_map = {"user": "human", "assistant": "gpt"} + dialog = eval(str(sample[ColumnName.dialog.value])) + + assert len(dialog) > 1, "dialog must have at least 2 messagse" + roles = [] + conversations = [] + for message in dialog: + assert ( + "role" in message and "content" in message + ), "role and content must in message" + roles.append(message["role"]) + conversations.append( + {"from": role_map[message["role"]], "value": message["content"]} + ) + + assert roles[0] == "user", "first message must be from user" + assert "assistant" in roles, "at least 1 message should be from assistant" + + return {"conversations": conversations} diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index 1f91dc73f..1a5aade09 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -19,6 +19,11 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages from torchtune.modules.transforms import Transform +from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import ( + llama_stack_chat_to_torchtune_chat, + llama_stack_instruct_to_torchtune_instruct, +) + class SFTDataset(Dataset): def __init__( @@ -26,10 +31,12 @@ class SFTDataset(Dataset): rows: List[Dict[str, Any]], message_transform: Transform, model_transform: Transform, + dataset_type: str, ) -> None: self._rows = rows self._message_transform = message_transform self._model_transform = model_transform + self._dataset_type = dataset_type def __len__(self): return len(self._rows) @@ -39,6 +46,12 @@ class SFTDataset(Dataset): return self._prepare_sample(sample) def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: + if self._dataset_type == "instruct": + sample = llama_stack_instruct_to_torchtune_instruct(sample) + elif self._dataset_type == "dialog": + sample = llama_stack_chat_to_torchtune_chat(sample) + else: + raise ValueError(f"Invalid dataset type: {self._dataset_type}") transformed_sample = self._message_transform(sample) if "messages" in transformed_sample: validate_messages(transformed_sample["messages"]) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 6c795d310..7543b1f4e 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -18,7 +18,7 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import modules, training, utils as torchtune_utils -from torchtune.data import AlpacaToMessages, padded_collate_sft +from torchtune.data import padded_collate_sft from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( @@ -47,6 +47,9 @@ from llama_stack.apis.post_training import ( from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.providers.inline.post_training.common.validator import ( + validate_input_dataset_schema, +) from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( @@ -129,8 +132,10 @@ class LoraFinetuningSingleDevice: self.seed = training.set_seed(seed=config.torch_seed) self.epochs_run = 0 self.total_epochs = training_config.n_epochs + self._data_format = training_config.data_config.data_format self._shuffle = training_config.data_config.shuffle self._batch_size = training_config.data_config.batch_size + self._train_on_input = training_config.data_config.train_on_input # this is important for debugging purpose self.max_steps_per_epoch = training_config.max_steps_per_epoch @@ -354,18 +359,17 @@ class LoraFinetuningSingleDevice: all_rows = await fetch_rows(dataset_id) rows = all_rows.rows - # Curretly only support alpaca instruct dataset - # TODO @SLR722 make the message_transform swappable and support more dataset types - # TODO @SLR722 make the input dataset schema more flexible by exposing column_map - await utils.validate_input_dataset_schema( + await validate_input_dataset_schema( datasets_api=self.datasets_api, dataset_id=dataset_id, - dataset_type="alpaca", + dataset_type=self._data_format.value, ) + data_transform = await utils.get_data_transform(self._data_format) ds = SFTDataset( rows, - message_transform=AlpacaToMessages(train_on_input=False), + message_transform=data_transform(train_on_input=self._train_on_input), model_transform=tokenizer, + dataset_type=self._data_format.value, ) sampler = DistributedSampler( diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py index af58a4592..55f1078a4 100644 --- a/llama_stack/providers/utils/common/data_schema_validator.py +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -23,6 +23,7 @@ class ColumnName(Enum): completion_input = "completion_input" generated_answer = "generated_answer" context = "context" + dialog = "dialog" VALID_SCHEMAS_FOR_SCORING = [ diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml index aa7695bca..e04868199 100644 --- a/llama_stack/templates/experimental-post-training/build.yaml +++ b/llama_stack/templates/experimental-post-training/build.yaml @@ -13,6 +13,7 @@ distribution_spec: post_training: - inline::torchtune datasetio: + - inline::localfs - remote::huggingface telemetry: - inline::meta-reference @@ -22,4 +23,6 @@ distribution_spec: - inline::llama-guard memory: - inline::faiss + tool_runtime: + - remote::brave-search image_type: conda diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index a654c375e..4a7bb5c47 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -12,6 +12,7 @@ apis: - scoring - telemetry - post_training +- tool_runtime providers: inference: - provider_id: meta-reference-inference @@ -32,6 +33,9 @@ providers: - provider_id: huggingface-0 provider_type: remote::huggingface config: {} + - provider_id: localfs + provider_type: inline::localfs + config: {} telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -60,6 +64,13 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + metadata_store: namespace: null @@ -68,23 +79,6 @@ metadata_store: models: [] shields: [] memory_banks: [] -datasets: - - dataset_id: alpaca - provider_id: huggingface-0 - url: - uri: https://huggingface.co/datasets/tatsu-lab/alpaca - metadata: - path: tatsu-lab/alpaca - name: - split: train - dataset_schema: - instruction: - type: string - input: - type: string - output: - type: string - text: - type: string +datasets: [] scoring_fns: [] eval_tasks: []