From 346a6c658d8c1674734a52c48cf97cdb282da3bb Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 3 Jan 2025 13:28:39 -0800
Subject: [PATCH] temp commit

---
 .../apis/post_training/post_training.py       | 11 +++-
 .../post_training/torchtune/common/utils.py   | 56 ++++++++++++++++++-
 .../recipes/lora_finetuning_single_device.py  | 15 +++--
 3 files changed, 73 insertions(+), 9 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 8e1edbe87..d9dae9e5c 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -27,14 +27,24 @@ class OptimizerType(Enum):
     sgd = "sgd"
 
 
+@json_schema_type
+class DatasetFormat(Enum):
+    alpaca = "alpaca"
+    instruct = "instruct"
+    chat_sharegpt = "chat_sharegpt"
+    chat_openai = "chat_openai"
+
+
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
     batch_size: int
     shuffle: bool
+    data_format: DatasetFormat
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
+    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
@@ -58,7 +68,6 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index a5279cdbe..0c2b663d3 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -11,19 +11,28 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
+
 from llama_stack.apis.common.type_system import ParamType, StringType
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import DatasetFormat
 
 from pydantic import BaseModel
 
+from torchtune.data._messages import (
+    AlpacaToMessages,
+    InputOutputToMessages,
+    OpenAIToMessages,
+    ShareGPTToMessages,
+)
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from torchtune.modules.transforms import Transform
 
 
 class ColumnName(Enum):
@@ -31,6 +40,8 @@ class ColumnName(Enum):
     input = "input"
     output = "output"
     text = "text"
+    conversations = "conversations"
+    messages = "messages"
 
 
 class ModelConfig(BaseModel):
@@ -41,6 +52,9 @@ class ModelConfig(BaseModel):
 
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
+    instruct: Dict[str, ParamType]
+    chat_sharegpt: Dict[str, ParamType]
+    chat_openai: Dict[str, ParamType]
 
 
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -56,6 +70,13 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
+DATA_FORMATS: Dict[str, Transform] = {
+    "alpaca": AlpacaToMessages,
+    "instruct": InputOutputToMessages,
+    "chat_sharegpt": ShareGPTToMessages,
+    "chat_openai": OpenAIToMessages,
+}
+
 
 EXPECTED_DATASET_SCHEMA = DatasetSchema(
     alpaca=[
@@ -74,7 +95,17 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.instruction.value: StringType(),
             ColumnName.output.value: StringType(),
         },
-    ]
+    ],
+    instruct={
+        ColumnName.input.value: StringType(),
+        ColumnName.output.value: StringType(),
+    },
+    chat_sharegpt={
+        ColumnName.conversations.value: StringType(),
+    },
+    chat_openai={
+        ColumnName.messages.value: StringType(),
+    },
 )
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
@@ -122,10 +153,15 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type
 
 
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
+
+
 async def validate_input_dataset_schema(
     datasets_api: Datasets,
     dataset_id: str,
     dataset_type: str,
+    column_map: Optional[Dict[str, str]] = None,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -134,7 +170,21 @@ async def validate_input_dataset_schema(
     if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(f"Dataset type {dataset_type} is not supported.")
 
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+    dataset_schema = {}
+
+    if column_map:
+        for old_col_name in dataset_def.dataset_schema.keys():
+            if old_col_name in column_map.values():
+                new_col_name = next(
+                    k for k, v in column_map.items() if v == old_col_name
+                )
+                dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
+            else:
+                dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
+    else:
+        dataset_schema = dataset_def.dataset_schema
+
+    if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(
             f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
         )
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index a2ef1c5dd..8e1c65b26 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -42,7 +42,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 
 from torchtune.modules.peft import (
@@ -129,15 +129,16 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
+        self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
+        self._column_map = training_config.data_config.column_map
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
-        self.max_validation_steps = training_config.max_validation_steps
         self._clip_grad_norm = 1.0
 
         self._enable_activation_checkpointing = (
@@ -360,11 +361,15 @@
         await utils.validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
-            dataset_type="alpaca",
+            dataset_type=self._data_format.value,
+            column_map=self._column_map,
         )
+        data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=AlpacaToMessages(train_on_input=False),
+            message_transform=data_transform(
+                train_on_input=False, column_map=self._column_map
+            ),
             model_transform=tokenizer,
         )
 
@@ -584,7 +589,7 @@
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == self.max_validation_steps:
+            if idx == 10:
                 break
 
             torchtune_utils.batch_to_device(batch, self._device)
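
Note (not part of the patch): a minimal sketch of how a caller might populate the new DataConfig fields introduced above, assuming the usual re-export from llama_stack.apis.post_training. The dataset id and source column names are hypothetical.

    from llama_stack.apis.post_training import DataConfig, DatasetFormat

    # Illustrative only: "my_sft_dataset", "prompt", and "response" are hypothetical names.
    data_config = DataConfig(
        dataset_id="my_sft_dataset",
        batch_size=4,
        shuffle=True,
        data_format=DatasetFormat.instruct,
        # column_map maps the expected column names (keys) to the dataset's actual
        # column names (values), matching how validate_input_dataset_schema() renames
        # columns before checking and how the torchtune message transforms consume it.
        column_map={"input": "prompt", "output": "response"},
    )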