[1/n] torchtune <> llama-stack integration skeleton (#540)

### Context 
This is the 1st of series PRs that integrate torchtune with llama-stack
as meta reference post-training implementation. For MVP, we will focus
on single device LoRA SFT.

Though this PR is still WIP, we want to get early feedback on the high
level design of this skeleton while still working on several details

### Scope
To limit the scope of this PR, we focus on the skeleton of the
implementation.

**What are included?**
- refine the post-training SFT apis
- skeleton of supervised_fine_tune implementation. We verified that we
can call the supervised_fine_tune API successfully from llama stack
client SDK (client side PR:
https://github.com/meta-llama/llama-stack-client-python/pull/51)
- a very basic single device LoRA training recipe based on torchtune
core components
- parity check with torchtune library and post training api unit test

**What are not includes?**
- implementation of other job management, get training artifacts apis
(separate PR)
- refactor the meta reference inference logic to support eval on
finetuned model (separate PR)
- several necessary functionality in the training recipe such as
logging, validation etc (separate PR)
- interop with telemetry for tracing and metrics logging, currently
temporarily log to local disk (separate PR)

### Testing
**e2e test**
Although we haven't added detailed testing and numerical parity check
with torchtune yet, we did a simple E2E test from client to server
1. setup server with` llama stack build --template
experimental-post-training --image-type conda` and `llama stack run
experimental-post-training `
2. On client, run `llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 post_training
supervised_fine_tune`
3. Training finishes successfully. On server side, get the finetune
checkpoints under output dir. On client side, get the job uuid

server 
<img width="1110" alt="Screenshot 2024-12-02 at 5 52 32 PM"
src="https://github.com/user-attachments/assets/b548eb90-7a9b-4edc-a858-ee237cc4361d">

client 
<img width="807" alt="Screenshot 2024-12-02 at 5 52 37 PM"
src="https://github.com/user-attachments/assets/1138ffa8-4698-40fa-b190-3d7b99646838">

**parity check**
torchtune dataloader output and llama-stack post training dataloader
output are same
<img width="1116" alt="Screenshot 2024-12-04 at 8 18 46 PM"
src="https://github.com/user-attachments/assets/5e295cdc-4c24-4ea6-82c0-ca96ef1bd6ee">

torchtune LoRA SFT and llama-stack post training LoRA SFT on alpaca
dataset with llama3.2 3B instruct model are numerical match

<img width="860" alt="Screenshot 2024-12-04 at 8 17 01 PM"
src="https://github.com/user-attachments/assets/c05cf0a8-c674-4d2e-9f0a-c5d01b2dca99">

<img width="1049" alt="Screenshot 2024-12-04 at 8 17 06 PM"
src="https://github.com/user-attachments/assets/b911d4e2-e7b1-41a9-b62c-d75529b6d443">

**unit test ** 
![Uploading Screenshot 2024-12-09 at 1.35.10 PM.png…]()
This commit is contained in:
Botao Chen 2024-12-13 11:05:35 -08:00 committed by GitHub
parent 53b3a1e345
commit aeb76390fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1172 additions and 68 deletions

View file

@ -6,50 +6,60 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@json_schema_type
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
batch_size: int
shuffle: bool
n_iters: int
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
max_steps_per_epoch: int
gradient_accumulation_steps: int
data_config: DataConfig
optimizer_config: OptimizerConfig
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
@json_schema_type
@ -59,16 +69,19 @@ class LoraFinetuningConfig(BaseModel):
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False
@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
class QATFinetuningConfig(BaseModel):
quantizer_name: str
group_size: int
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
AlgorithmConfig = Annotated[
Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
]
@json_schema_type
@ -100,29 +113,6 @@ class DPOAlignmentConfig(BaseModel):
gamma: float
@json_schema_type
class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
model: str
dataset_id: str
validation_dataset_id: str
algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
@ -135,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel):
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig]
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
@ -177,53 +167,49 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune(
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
def preference_optimize(
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: URL,
dataset_id: str,
validation_dataset_id: str,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
def get_training_job_status(
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts(
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
Api.post_training: PostTraining,
}

View file

@ -28,6 +28,7 @@ class Api(Enum):
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
post_training = "post_training"
telemetry = "telemetry"

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import TorchtunePostTrainingConfig
# post_training api and the torchtune provider is still experimental and under heavy development
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
):
from .post_training import TorchtunePostTrainingImpl
impl = TorchtunePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Mapping
import numpy as np
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform
class SFTDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform
def __len__(self):
return len(self._rows)
def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._rows[index]
return self._prepare_sample(sample)
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
tokenized_dict = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)
# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
tokenized_dict["labels"] = list(
np.where(
tokenized_dict["mask"],
CROSS_ENTROPY_IGNORE_IDX,
tokenized_dict["tokens"],
)
)
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])
return tokenized_dict

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.apis.post_training import * # noqa
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
class TorchtunePostTrainingImpl:
def __init__(
self,
config: TorchtunePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
) -> PostTrainingJob:
if isinstance(algorithm_config, LoraFinetuningConfig):
recipe = LoraFinetuningSingleDevice(
self.config,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
self.datasets_api,
)
await recipe.setup()
await recipe.train()
else:
raise NotImplementedError()
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
# TODO @SLR722 impelment below APIs
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -0,0 +1,506 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
log = logging.getLogger(__name__)
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
class LoraFinetuningSingleDevice:
# This recipe only supports GPU training
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
# - compile
# - activation offloading
# Resume from checkpoint hasn't been supported yet
# Validation hasn't been supported yet
# Currently logging only logs limited training metrics to local disk
# will figure out more loggings and how it works with telemetry in future PRs
def __init__(
self,
config: TorchtunePostTrainingConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.training_config = training_config
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
self.model_id = model
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
)
return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = config.checkpoint_dir
else:
model = resolve_model(self.model_id)
self.checkpoint_dir = model_checkpoint_dir(model)
# TODO @SLR722 make it work with get_training_job_artifacts
self._output_dir = self.checkpoint_dir + "/posting_training/"
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
# this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch
self.global_step = 0
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
(training_config.efficiency_config.enable_activation_checkpointing)
if training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(training_config.efficiency_config.enable_activation_offloading)
if training_config.efficiency_config
else False
)
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
async def load_checkpoint(self):
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
try:
# List all files in the given directory
files = os.listdir(checkpoint_dir)
# Filter files that end with .pth
pth_files = [file for file in files if file.endswith(".pth")]
return pth_files
except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = training.FullModelMetaCheckpointer(
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
model_type=await utils.get_checkpointer_model_type(self.model_id),
)
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict
async def setup(self) -> None:
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
enable_activation_checkpointing=self._enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=None,
)
log.info(f"Model is initialized with precision {self._dtype}.")
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer(
optimizer_config=self.training_config.optimizer_config
)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.")
self._sampler, self._dataloader = await self._setup_data(
tokenizer=self._tokenizer,
shuffle=self._shuffle,
batch_size=self._batch_size,
)
log.info("Dataset and Sampler are initialized.")
# Number of training steps in each epoch depends on the number of batches produced
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
self._steps_per_epoch = (
len(self._dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
# Learning rate scheduler can only be set up after number of steps
# has been computed
self._lr_scheduler = await self._setup_lr_scheduler(
num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
async def _setup_model(
self,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
self._lora_rank = self.algorithm_config.rank
self._lora_alpha = self.algorithm_config.alpha
self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
self._use_dora = self.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device:
model_type = await utils.get_model_definition(self.model_id)
model = model_type(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
lora_rank=self._lora_rank,
lora_alpha=self._lora_alpha,
quantize_base=False,
use_dora=self._use_dora,
)
self.adapter_params = get_adapter_params(model)
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
if self._is_dora:
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
load_dora_magnitudes(model)
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
return model
async def _setup_tokenizer(
self,
) -> Llama3Tokenizer:
tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
tokenizer_type = await utils.get_tokenizer_type(self.model_id)
return tokenizer_type(path=tokenizer_path)
async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
optimizer = torch.optim.AdamW(
params=self._model.parameters(),
lr=optimizer_config.lr,
betas=(0.9, 0.95),
eps=1e-8,
weight_decay=0.1,
)
return optimizer
async def _setup_data(
self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
) -> Tuple[DistributedSampler, DataLoader]:
dataset_id = self.training_config.data_config.dataset_id
async def fetch_rows():
return await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
all_rows = await fetch_rows()
rows = all_rows.rows
# Curretly only support alpaca instruct dataset
# TODO @SLR722 make the message_transform swappable and support more dataset types
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
await utils.validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type="alpaca",
)
ds = SFTDataset(
rows,
message_transform=AlpacaToMessages(train_on_input=False),
model_transform=tokenizer,
)
sampler = DistributedSampler(
ds,
num_replicas=1,
rank=0,
shuffle=shuffle,
seed=0,
)
dataloader = DataLoader(
dataset=ds,
sampler=sampler,
batch_size=batch_size,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
),
)
return sampler, dataloader
async def _setup_lr_scheduler(
self,
num_warmup_steps: int,
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
lr_scheduler = get_cosine_schedule_with_warmup(
self._optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)
return lr_scheduler
async def save_checkpoint(self, epoch: int) -> None:
ckpt_dict = {}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
# Construct the full state dict with LoRA weights merged into base LLM weights
# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
}
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
)
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
# run model
with self.activations_handling_ctx:
logits = self._model(**batch)
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits
return loss
async def train(self) -> None:
"""
The core training loop.
"""
# Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
):
break
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens)
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)
# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()
self.epochs_run += 1
log.info("Starting checkpoint save...")
await self.save_checkpoint(epoch=curr_epoch)

View file

@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Callable, Dict, List
import torch
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.common.type_system import * # noqa
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b
class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
checkpoint_type: str
class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]
MODEL_CONFIGS: Dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
"Llama-3-8B-Instruct": ModelConfig(
model_definition=lora_llama3_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),
}
EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
ColumnName.text.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
]
)
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
raise ValueError(f"Model {model_id} is not supported.")
return model
async def get_model_definition(
model_id: str,
) -> BuildLoraModelCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "model_definition"):
raise ValueError(f"Model {model_id} does not have model definition.")
return model_config.model_definition
async def get_tokenizer_type(
model_id: str,
) -> BuildTokenizerCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "tokenizer_type"):
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
return model_config.tokenizer_type
async def get_checkpointer_model_type(
model_id: str,
) -> str:
"""
checkpointer model type is used in checkpointer for some special treatment on some specific model types
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
"""
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "checkpoint_type"):
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
return model_config.checkpoint_type
async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::torchtune",
pip_packages=["torch", "torchtune", "torchao", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
]

View file

@ -156,4 +156,5 @@ pytest_plugins = [
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
]

View file

@ -10,6 +10,7 @@ import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from .fixtures import POST_TRAINING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"post_training": "torchtune",
"datasetio": "huggingface",
},
id="torchtune_post_training_huggingface_datasetio",
marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
),
]
def pytest_configure(config):
combined_fixtures = "torchtune_post_training_huggingface_datasetio"
config.addinivalue_line(
"markers",
f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
)
def pytest_generate_tests(metafunc):
if "post_training_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": POST_TRAINING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("post_training_stack", combinations, indirect=True)

View file

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def post_training_torchtune() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="torchtune",
provider_type="inline::torchtune",
config={},
)
],
)
POST_TRAINING_FIXTURES = ["torchtune"]
@pytest_asyncio.fixture(scope="session")
async def post_training_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["post_training", "datasetio"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.post_training, Api.datasetio],
providers,
provider_data,
models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
datasets=[
DatasetInput(
dataset_id="alpaca",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
metadata={
"path": "tatsu-lab/alpaca",
"split": "train",
},
dataset_schema={
"instruction": StringType(),
"input": StringType(),
"output": StringType(),
"text": StringType(),
},
),
],
)
return test_stack.impls[Api.post_training]

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
# How to run this test:
#
# pytest llama_stack/providers/tests/post_training/test_post_training.py
# -m "torchtune_post_training_huggingface_datasetio"
# -v -s --tb=short --disable-warnings
class TestPostTraining:
@pytest.mark.asyncio
async def test_supervised_fine_tune(self, post_training_stack):
algorithm_config = LoraFinetuningConfig(
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
rank=8,
alpha=16,
)
data_config = DataConfig(
dataset_id="alpaca",
batch_size=1,
shuffle=False,
)
optimizer_config = OptimizerConfig(
optimizer_type="adamw",
lr=3e-4,
lr_min=3e-5,
weight_decay=0.1,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
max_steps_per_epoch=1,
gradient_accumulation_steps=1,
)
post_training_impl = post_training_stack
response = await post_training_impl.supervised_fine_tune(
job_uuid="1234",
model="Llama3.2-3B-Instruct",
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config={},
logger_config={},
checkpoint_dir="null",
)
assert isinstance(response, PostTrainingJob)
assert response.job_uuid == "1234"

View file

@ -0,0 +1,13 @@
version: '2'
name: experimental-post-training
distribution_spec:
description: Experimental template for post training
docker_image: null
providers:
post_training:
- inline::torchtune
datasetio:
- remote::huggingface
telemetry:
- inline::meta-reference
image_type: conda

View file

@ -0,0 +1,53 @@
version: '2'
image_name: experimental-post-training
docker_image: null
conda_env: experimental-post-training
apis:
- telemetry
- datasetio
- post_training
providers:
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
config: {}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
post_training:
- provider_id: torchtune-post-training
provider_type: inline::torchtune
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.POST_TRAINING_MODEL}
provider_id: meta-reference-inference
provider_model_id: null
shields: []
memory_banks: []
datasets:
- dataset_id: alpaca
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
metadata:
path: tatsu-lab/alpaca
name:
split: train
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
text:
type: string
scoring_fns: []
eval_tasks: []