mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00

temp commit

This commit is contained in:
parent 96d8375663
commit 346a6c658d

3 changed files with 73 additions and 9 deletions
@@ -27,14 +27,24 @@ class OptimizerType(Enum):
     sgd = "sgd"
 
 
+@json_schema_type
+class DatasetFormat(Enum):
+    alpaca = "alpaca"
+    instruct = "instruct"
+    chat_sharegpt = "chat_sharegpt"
+    chat_openai = "chat_openai"
+
+
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
     batch_size: int
     shuffle: bool
+    data_format: DatasetFormat
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
+    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
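For illustration, a client-side config exercising the two new fields might look like the following sketch. The dataset ID and column names are hypothetical; DataConfig and DatasetFormat are the datatypes defined in the hunk above, and the import path is the one this commit uses elsewhere.

    from llama_stack.apis.post_training import DataConfig, DatasetFormat

    # Hypothetical dataset whose columns are named "question"/"answer" rather
    # than the expected "input"/"output"; column_map declares that renaming.
    data_config = DataConfig(
        dataset_id="qa-pairs",  # hypothetical dataset ID
        batch_size=8,
        shuffle=True,
        data_format=DatasetFormat.instruct,
        column_map={"input": "question", "output": "answer"},
    )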
@@ -58,7 +68,6 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
@@ -11,19 +11,28 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.common.type_system import ParamType, StringType
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import DatasetFormat
 
 from pydantic import BaseModel
+from torchtune.data._messages import (
+    AlpacaToMessages,
+    InputOutputToMessages,
+    OpenAIToMessages,
+    ShareGPTToMessages,
+)
 
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from torchtune.modules.transforms import Transform
 
 
 class ColumnName(Enum):
@@ -31,6 +40,8 @@ class ColumnName(Enum):
     input = "input"
     output = "output"
     text = "text"
+    conversations = "conversations"
+    messages = "messages"
 
 
 class ModelConfig(BaseModel):
@@ -41,6 +52,9 @@ class ModelConfig(BaseModel):
 
 
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
+    instruct: Dict[str, ParamType]
+    chat_sharegpt: Dict[str, ParamType]
+    chat_openai: Dict[str, ParamType]
 
 
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -56,6 +70,13 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
+DATA_FORMATS: Dict[str, Transform] = {
+    "alpaca": AlpacaToMessages,
+    "instruct": InputOutputToMessages,
+    "chat_sharegpt": ShareGPTToMessages,
+    "chat_openai": OpenAIToMessages,
+}
+
 
 EXPECTED_DATASET_SCHEMA = DatasetSchema(
     alpaca=[
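Note that the values in this mapping are transform classes, not instances; they get instantiated later in the recipe. A rough sketch of how the mapping is consumed, with hypothetical column names (the transforms come from torchtune.data._messages as imported above, and each accepts train_on_input and column_map):

    # Resolve a format name to its torchtune message-transform class,
    # then instantiate it with per-dataset options.
    transform_cls = DATA_FORMATS["instruct"]  # -> InputOutputToMessages
    message_transform = transform_cls(
        train_on_input=False,
        column_map={"input": "question", "output": "answer"},  # hypothetical columns
    )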
@@ -74,7 +95,17 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.instruction.value: StringType(),
             ColumnName.output.value: StringType(),
         },
-    ]
+    ],
+    instruct={
+        ColumnName.input.value: StringType(),
+        ColumnName.output.value: StringType(),
+    },
+    chat_sharegpt={
+        ColumnName.conversations.value: StringType(),
+    },
+    chat_openai={
+        ColumnName.messages.value: StringType(),
+    },
 )
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
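For reference, the two chat schemas correspond to the usual ShareGPT and OpenAI turn structures. A rough sketch of one row in each format as torchtune's transforms expect it; whether the column arrives already parsed or as a serialized string depends on the dataset provider, and the schema above only asserts a single column of StringType():

    # ShareGPT-style row: one "conversations" column of from/value turns.
    sharegpt_row = {
        "conversations": [
            {"from": "human", "value": "What is LoRA?"},
            {"from": "gpt", "value": "A parameter-efficient finetuning method."},
        ]
    }

    # OpenAI-style row: one "messages" column of role/content turns.
    openai_row = {
        "messages": [
            {"role": "user", "content": "What is LoRA?"},
            {"role": "assistant", "content": "A parameter-efficient finetuning method."},
        ]
    }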
@@ -122,10 +153,15 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type
 
 
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
+
+
 async def validate_input_dataset_schema(
     datasets_api: Datasets,
     dataset_id: str,
     dataset_type: str,
+    column_map: Optional[Dict[str, str]] = None,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
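A hypothetical call into the extended validator, for a dataset whose columns do not match the expected names (the dataset ID and column names are made up for illustration; datasets_api is an instance of the Datasets API imported above):

    await validate_input_dataset_schema(
        datasets_api=datasets_api,
        dataset_id="qa-pairs",  # hypothetical dataset
        dataset_type="instruct",
        column_map={"input": "question", "output": "answer"},
    )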
@@ -134,7 +170,21 @@ async def validate_input_dataset_schema(
     if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(f"Dataset type {dataset_type} is not supported.")
 
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+    dataset_schema = {}
+
+    if column_map:
+        for old_col_name in dataset_def.dataset_schema.keys():
+            if old_col_name in column_map.values():
+                new_col_name = next(
+                    k for k, v in column_map.items() if v == old_col_name
+                )
+                dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
+            else:
+                dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
+    else:
+        dataset_schema = dataset_def.dataset_schema
+
+    if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(
             f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
         )
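The remapping loop reads column_map as expected-name -> actual-name, the same orientation torchtune uses: each actual dataset column that appears among the map's values is renamed back to its expected key before the schema comparison. A minimal sketch of the effect, with hypothetical names:

    # Dataset as registered: columns named "question" and "answer".
    registered_schema = {"question": StringType(), "answer": StringType()}
    column_map = {"input": "question", "output": "answer"}

    # After the loop above, the schema is keyed by the expected names,
    # matching the "instruct" entry of EXPECTED_DATASET_SCHEMA:
    # {"input": StringType(), "output": StringType()}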
@@ -42,7 +42,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft
 
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -129,15 +129,16 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
+        self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
+        self._column_map = training_config.data_config.column_map
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
-        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -360,11 +361,15 @@ class LoraFinetuningSingleDevice:
         await utils.validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
-            dataset_type="alpaca",
+            dataset_type=self._data_format.value,
+            column_map=self._column_map,
         )
+        data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=AlpacaToMessages(train_on_input=False),
+            message_transform=data_transform(
+                train_on_input=False, column_map=self._column_map
+            ),
             model_transform=tokenizer,
         )
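To see what the resolved transform does per row: with data_format=chat_openai, get_data_transform returns OpenAIToMessages, and SFTDataset applies an instance of it to each raw row before tokenization. A rough sketch of that single-row step, assuming a hypothetical sample (torchtune message transforms are callables over a sample dict):

    from torchtune.data._messages import OpenAIToMessages

    to_messages = OpenAIToMessages(train_on_input=False)
    sample = {
        "messages": [
            {"role": "user", "content": "What is LoRA?"},
            {"role": "assistant", "content": "A parameter-efficient finetuning method."},
        ]
    }
    # Returns {"messages": [Message(...), ...]}, ready for the tokenizer transform.
    transformed = to_messages(sample)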
@@ -584,7 +589,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == self.max_validation_steps:
+            if idx == 10:
                 break
             torchtune_utils.batch_to_device(batch, self._device)