mirror of https://github.com/meta-llama/llama-stack.git

temp commit

parent 96d8375663 · commit 346a6c658d
3 changed files with 73 additions and 9 deletions
@@ -27,14 +27,24 @@ class OptimizerType(Enum):
     sgd = "sgd"
 
 
 @json_schema_type
+class DatasetFormat(Enum):
+    alpaca = "alpaca"
+    instruct = "instruct"
+    chat_sharegpt = "chat_sharegpt"
+    chat_openai = "chat_openai"
+
+
+@json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
     batch_size: int
     shuffle: bool
+    data_format: DatasetFormat
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
+    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
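For reference, a DataConfig using the new fields might look like the sketch below. The dataset_id, batch size, and column names are hypothetical; only the field names come from this diff. The column_map direction (expected column name mapped to the actual column name in the registered dataset) follows the remapping logic added to validate_input_dataset_schema later in this commit.

# Hypothetical configuration; data_format and column_map are the fields added here.
data_config = DataConfig(
    dataset_id="my_chat_dataset",              # hypothetical dataset registered via the Datasets API
    batch_size=8,
    shuffle=True,
    data_format=DatasetFormat.chat_sharegpt,   # selects the torchtune message transform
    column_map={"conversations": "dialogue"},  # expected column -> actual dataset column
)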
@@ -58,7 +68,6 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
@@ -11,19 +11,28 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.common.type_system import ParamType, StringType
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import DatasetFormat
 
 from pydantic import BaseModel
+from torchtune.data._messages import (
+    AlpacaToMessages,
+    InputOutputToMessages,
+    OpenAIToMessages,
+    ShareGPTToMessages,
+)
+
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from torchtune.modules.transforms import Transform
 
 
 class ColumnName(Enum):
@@ -31,6 +40,8 @@ class ColumnName(Enum):
     input = "input"
     output = "output"
     text = "text"
+    conversations = "conversations"
+    messages = "messages"
 
 
 class ModelConfig(BaseModel):
@@ -41,6 +52,9 @@ class ModelConfig(BaseModel):
 
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
+    instruct: Dict[str, ParamType]
+    chat_sharegpt: Dict[str, ParamType]
+    chat_openai: Dict[str, ParamType]
 
 
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -56,6 +70,13 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
+DATA_FORMATS: Dict[str, Transform] = {
+    "alpaca": AlpacaToMessages,
+    "instruct": InputOutputToMessages,
+    "chat_sharegpt": ShareGPTToMessages,
+    "chat_openai": OpenAIToMessages,
+}
+
 
 EXPECTED_DATASET_SCHEMA = DatasetSchema(
     alpaca=[
@@ -74,7 +95,17 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.instruction.value: StringType(),
             ColumnName.output.value: StringType(),
         },
-    ]
+    ],
+    instruct={
+        ColumnName.input.value: StringType(),
+        ColumnName.output.value: StringType(),
+    },
+    chat_sharegpt={
+        ColumnName.conversations.value: StringType(),
+    },
+    chat_openai={
+        ColumnName.messages.value: StringType(),
+    },
 )
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
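To make the schema concrete, here are sketched example rows for each supported format. The values are invented; the top-level column names match EXPECTED_DATASET_SCHEMA above, and the nested message layouts follow the usual ShareGPT and OpenAI conventions consumed by the corresponding torchtune transforms.

# Illustrative rows only; the values are made up.
alpaca_row = {"instruction": "Summarize LoRA in one sentence.", "output": "LoRA fine-tunes models with low-rank adapters."}
instruct_row = {"input": "What is LoRA?", "output": "A parameter-efficient fine-tuning method."}
chat_sharegpt_row = {"conversations": [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]}
chat_openai_row = {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}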
@@ -122,10 +153,15 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type
 
 
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
+
+
 async def validate_input_dataset_schema(
     datasets_api: Datasets,
     dataset_id: str,
     dataset_type: str,
+    column_map: Optional[Dict[str, str]] = None,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
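Note that DATA_FORMATS maps format names to the transform classes themselves, so get_data_transform returns a class that the caller still has to instantiate; that is how the LoRA recipe uses it later in this commit. A minimal usage sketch inside an async context (the column_map value is hypothetical):

# Returns the transform class, e.g. InputOutputToMessages for the "instruct" format.
transform_cls = await get_data_transform(DatasetFormat.instruct)
message_transform = transform_cls(
    train_on_input=False,
    column_map={"input": "question", "output": "answer"},  # hypothetical mapping
)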
@@ -134,7 +170,21 @@ async def validate_input_dataset_schema(
     if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(f"Dataset type {dataset_type} is not supported.")
 
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+    dataset_schema = {}
+
+    if column_map:
+        for old_col_name in dataset_def.dataset_schema.keys():
+            if old_col_name in column_map.values():
+                new_col_name = next(
+                    k for k, v in column_map.items() if v == old_col_name
+                )
+                dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
+            else:
+                dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
+    else:
+        dataset_schema = dataset_def.dataset_schema
+
+    if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(
             f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
         )
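The column_map handling above renames the registered dataset's columns back to the expected names before the schema comparison. A small worked example, assuming a hypothetical "instruct"-style dataset whose columns are named question/answer:

# Hypothetical inputs to the renaming loop above.
registered_schema = {"question": StringType(), "answer": StringType()}
column_map = {"input": "question", "output": "answer"}  # expected name -> actual column

# After the loop, dataset_schema is:
#   {"input": StringType(), "output": StringType()}
# i.e. the column names expected by EXPECTED_DATASET_SCHEMA.instruct.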
@@ -42,7 +42,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft
 
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -129,15 +129,16 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
+        self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
+        self._column_map = training_config.data_config.column_map
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
-        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -360,11 +361,15 @@ class LoraFinetuningSingleDevice:
         await utils.validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
-            dataset_type="alpaca",
+            dataset_type=self._data_format.value,
+            column_map=self._column_map,
         )
+        data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=AlpacaToMessages(train_on_input=False),
+            message_transform=data_transform(
+                train_on_input=False, column_map=self._column_map
+            ),
             model_transform=tokenizer,
         )
 
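Worth noting: the same column_map is consumed twice on this path, once by validate_input_dataset_schema (to rename the registered dataset's columns for the schema check) and once by the torchtune transform returned by get_data_transform (so it reads the actual column names when building messages). A compact sketch with a hypothetical mapping, mirroring the recipe code above:

# Hypothetical mapping reused in both places.
column_map = {"input": "question", "output": "answer"}
await utils.validate_input_dataset_schema(
    datasets_api=self.datasets_api,
    dataset_id=dataset_id,
    dataset_type=self._data_format.value,
    column_map=column_map,
)
data_transform = await utils.get_data_transform(self._data_format)
message_transform = data_transform(train_on_input=False, column_map=column_map)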
@@ -584,7 +589,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == self.max_validation_steps:
+            if idx == 10:
                 break
             torchtune_utils.batch_to_device(batch, self._device)
 