llama-stack-mirror/llama_stack/providers/inline/post_training/huggingface/config.py
Charlie Doern 6494658a10 feat: add finetune_multi_device recipe with fsdp support
The HF SFTTrainer supports distributed training using FSDP.

Add a new recipe, `finetune_multi_device`, which supports multi-GPU (cuda) training
using FSDP and, optionally, LoRA.

transformers hides a lot of its FSDP usage behind the training args:
a6b51e7341/src/transformers/training_args.py (L1535)

You need to pass both `fsdp` and `fsdp_config` to get it to work properly. However,
many of the `fsdp_config` entries seem to be silently ignored. The key settings to get this working were:
full_shard
offload (cpu offload)
transformer_layer_cls_to_wrap (model specific wrapping)
cpu_ram_efficient_loading
sharding_strategy
limit_all_gathers
sync_module_states
backward_prefetch
use_orig_params

These can be seen in both `fsdp=` and `fsdp_config=` in the `SFTConfig` call, as sketched below.
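
Roughly, the FSDP-related part of the `SFTConfig` call looks like the following sketch (illustrative only, not the recipe code verbatim; "LlamaDecoderLayer" is a placeholder, since the wrapping class is model specific):

from trl import SFTConfig

# Sketch of the FSDP-relevant arguments; values are examples only.
training_args = SFTConfig(
    output_dir="./checkpoints",
    # space-separated FSDP options: shard params/grads/optimizer state,
    # offload to CPU, and auto-wrap transformer layers
    fsdp="full_shard offload auto_wrap",
    fsdp_config={
        # model specific: the transformer block class to wrap
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "cpu_ram_efficient_loading": True,
        "limit_all_gathers": True,
        "sync_module_states": True,
        "backward_prefetch": "backward_pre",
        "use_orig_params": True,
    },
)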

I have tested this successfully with different model architectures, both with and without LoRA.

The user can now toggle `recipe` in their provider config between `single` and `multi` to select between the two recipes.

For debugging purposes, NCCL logging settings can now be configured via the provider config as well.
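
For illustration, a provider config using the new recipe with NCCL debugging enabled could be constructed like this (example values only; the field names match the config class below):

from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

config = HuggingFacePostTrainingConfig(
    device="cuda",
    distributed_backend="fsdp",
    recipe="multi",            # use the new multi-device recipe
    enable_nccl_debug=True,    # turn on NCCL debug logging
    nccl_debug_subsys="INIT",  # limit debug output to the INIT subsystem
)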

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-06-12 13:33:33 -04:00

83 lines
3.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal

from pydantic import BaseModel


class HuggingFacePostTrainingConfig(BaseModel):
    # Device to run training on (cuda, cpu, mps)
    device: str = "cuda"

    # Distributed training backend if using multiple devices
    # fsdp: Fully Sharded Data Parallel
    # deepspeed: DeepSpeed ZeRO optimization
    distributed_backend: Literal["fsdp", "deepspeed"] | None = None

    # Format for saving model checkpoints
    # full_state: Save complete model state
    # huggingface: Save in HuggingFace format (recommended for compatibility)
    checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"

    # Template for formatting chat inputs and outputs
    # Used to structure the conversation format for training
    chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"

    # Model-specific configuration parameters
    # trust_remote_code: Allow execution of custom model code
    # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
    model_specific_config: dict = {
        "trust_remote_code": True,
        "attn_implementation": "sdpa",
    }

    # Maximum sequence length for training
    # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
    # Longer sequences may cause memory issues on MPS devices
    max_seq_length: int = 2048

    # Enable gradient checkpointing to reduce memory usage
    # Trades computation for memory by recomputing activations
    gradient_checkpointing: bool = False

    # Maximum number of checkpoints to keep
    # Older checkpoints are deleted when this limit is reached
    save_total_limit: int = 3

    # Number of training steps between logging updates
    logging_steps: int = 10

    # Ratio of training steps used for learning rate warmup
    # Helps stabilize early training
    warmup_ratio: float = 0.1

    # L2 regularization coefficient
    # Helps prevent overfitting
    weight_decay: float = 0.00

    # Number of worker processes for data loading
    # Higher values can improve data loading speed but increase memory usage
    dataloader_num_workers: int = 4

    # Whether to pin memory in data loader
    # Can improve data transfer speed to GPU but uses more memory
    dataloader_pin_memory: bool = True

    # Recipe type for training (single or multi device)
    recipe: str = "single"

    # NCCL debug configuration for distributed training
    # Enable detailed NCCL logging for debugging distributed training issues
    enable_nccl_debug: bool = False

    # NCCL subsystems to debug (NONE, ALL, INIT, COLL, P2P, SHM, NET)
    # Controls which NCCL components generate debug output
    nccl_debug_subsys: str = "NONE"

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu", "recipe": "single"}
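
For reference, a hypothetical illustration of how the default chat_template could be filled in for a single training example (the recipe's actual formatting logic may differ):

# Hypothetical: substitute one example's fields into the default template.
template = "<|user|>\n{input}\n<|assistant|>\n{output}"
sample = template.format(
    input="What is the capital of France?",
    output="The capital of France is Paris.",
)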