Mirror of https://github.com/meta-llama/llama-stack.git
# What does this PR do?

Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a very popular option for running training jobs. The config allows a user to specify key fields such as the model, chat_template, device, etc.

The provider comes with one recipe, `finetune_single_device`, which works both with and without LoRA. Any valid HF model identifier can be given, and the model will be pulled automatically. This has been tested so far with the CPU and MPS device types, but should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes the various steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs n_epochs of training. If checkpoint_dir is none, no model is saved. If there is a checkpoint dir, a model is saved every `save_steps` and at the end of training.

## Test Plan

Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny Granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the llama stack client and the proper post_training API, and runs one step with a batch_size of 1. This test runs on CPU on the Ubuntu runner, so it needs a small batch and a single step.

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
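For a quick sense of the config surface this PR adds, here is a minimal sketch of constructing the provider config defined in the file below. The import path is hypothetical and the values are illustrative; the field names and defaults come from the config class itself.

```python
from config import HuggingFacePostTrainingConfig  # hypothetical import path

# A CPU/MPS-friendly setup like the one exercised in the test plan.
cfg = HuggingFacePostTrainingConfig(
    device="cpu",
    gradient_checkpointing=True,  # trade compute for memory on small machines
    max_seq_length=2048,
)
print(cfg.chat_template.format(input="question", output="answer"))
```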
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Literal

from pydantic import BaseModel

class HuggingFacePostTrainingConfig(BaseModel):
    # Device to run training on (cuda, cpu, mps)
    device: str = "cuda"

    # Distributed training backend if using multiple devices
    # fsdp: Fully Sharded Data Parallel
    # deepspeed: DeepSpeed ZeRO optimization
    distributed_backend: Literal["fsdp", "deepspeed"] | None = None

    # Format for saving model checkpoints
    # full_state: Save complete model state
    # huggingface: Save in HuggingFace format (recommended for compatibility)
    checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"

    # Template for formatting chat inputs and outputs
    # Used to structure the conversation format for training
    chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
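    # Worked example (illustrative): chat_template.format(input="What is 2+2?", output="4")
    # yields "<|user|>\nWhat is 2+2?\n<|assistant|>\n4"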

    # Model-specific configuration parameters
    # trust_remote_code: Allow execution of custom model code
    # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
    model_specific_config: dict = {
        "trust_remote_code": True,
        "attn_implementation": "sdpa",
    }
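    # Note: these keys match kwargs accepted by transformers' from_pretrained;
    # it is assumed the recipe forwards them there when loading the model.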

    # Maximum sequence length for training
    # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
    # Longer sequences may cause memory issues on MPS devices
    max_seq_length: int = 2048

    # Enable gradient checkpointing to reduce memory usage
    # Trades computation for memory by recomputing activations
    gradient_checkpointing: bool = False

    # Maximum number of checkpoints to keep
    # Older checkpoints are deleted when this limit is reached
    save_total_limit: int = 3

    # Number of training steps between logging updates
    logging_steps: int = 10

    # Ratio of training steps used for learning rate warmup
    # Helps stabilize early training
    warmup_ratio: float = 0.1
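    # Worked example (illustrative): with 500 total optimizer steps,
    # warmup_ratio = 0.1 ramps the learning rate up over the first 50 steps.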

    # L2 regularization coefficient
    # Helps prevent overfitting
    weight_decay: float = 0.01

    # Number of worker processes for data loading
    # Higher values can improve data loading speed but increase memory usage
    dataloader_num_workers: int = 4

    # Whether to pin memory in data loader
    # Can improve data transfer speed to GPU but uses more memory
    dataloader_pin_memory: bool = True

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
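
# Round-trip sanity sketch (illustrative, not part of the provider):
# sample_run_config() returns plain dict values that re-validate through the model.
#
#     cfg = HuggingFacePostTrainingConfig(
#         **HuggingFacePostTrainingConfig.sample_run_config("/tmp")
#     )
#     assert cfg.device == "cpu" and cfg.checkpoint_format == "huggingface"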