feat: add huggingface post_training impl

Adds an inline HF SFTTrainer provider. Alongside torchtune, this is one of the most popular options for running training jobs. The config allows a user to specify key fields such as the model, chat_template, device, etc.
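For illustration, a provider config might look roughly like this. The field names are taken from how the recipe reads `provider_config` in the code below; the import path and the concrete values are assumptions, not the authoritative schema:

```python
# Rough sketch only -- values are illustrative, and the import path is assumed
# from the provider layout; see HuggingFacePostTrainingConfig for the real schema.
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

provider_config = HuggingFacePostTrainingConfig(
    device="mps",  # "cuda", "cpu", or "mps"
    chat_template="### Input:\n{input}\n\n### Response:\n{output}",
    max_seq_length=2048,
    gradient_checkpointing=False,
)
```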

The provider ships with one recipe, `finetune_single_device`, which works both with and without LoRA.
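A minimal sketch of driving the recipe directly, with and without LoRA. The import paths, the `datasetio_api`/`datasets_api` handles, `training_config`, and the model identifier are placeholders, and only the `LoraFinetuningConfig` fields the recipe actually reads are shown (the full schema may require more):

```python
# Sketch only -- import paths are assumed, and the surrounding job plumbing
# (datasetio_api, datasets_api, training_config) is supplied by the stack.
from llama_stack.apis.post_training import LoraFinetuningConfig
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
    HFFinetuningSingleDevice,
)

recipe = HFFinetuningSingleDevice(
    job_uuid="job-1234",
    datasetio_api=datasetio_api,
    datasets_api=datasets_api,
)

# Pass a LoraFinetuningConfig for LoRA, or None for full-parameter SFT.
lora = LoraFinetuningConfig(rank=8, alpha=16, lora_attn_modules=["q_proj", "v_proj"])

memory_stats, checkpoints = await recipe.train(
    model="ibm-granite/granite-3.3-2b-instruct",  # any valid HF identifier
    output_dir="./checkpoints",
    job_uuid="job-1234",
    lora_config=lora,
    config=training_config,  # TrainingConfig: data_config, n_epochs, optimizer_config, ...
    provider_config=provider_config,  # from the sketch above
)
```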

Any model with a valid HF identifier can be given, and the model will be pulled automatically.

This has been tested so far with the CPU and MPS device types, but should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes the steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs `n_epochs` of training.
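As a worked example of the interval logic (mirroring the calculation in the recipe below), assuming 10,000 training examples and a batch size of 8:

```python
steps_per_epoch = 10_000 // 8                    # 1250 steps per epoch
eval_steps = max(1, steps_per_epoch // 10)       # 125  -> evaluate ~10 times per epoch
save_steps = max(1, steps_per_epoch // 5)        # 250  -> save ~5 times per epoch
logging_steps = max(1, steps_per_epoch // 50)    # 25   -> log ~50 times per epoch
```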

If `checkpoint_dir` is None, no model is saved. If a checkpoint dir is provided, a model is saved every `save_steps` and at the end of training.
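When a checkpoint dir is provided, the job also returns a single `Checkpoint` entry pointing at the merged model; the values below are illustrative:

```python
from datetime import datetime, timezone

from llama_stack.apis.post_training import Checkpoint

# Mirrors the checkpoint the recipe constructs at the end of training (values illustrative).
Checkpoint(
    identifier="ibm-granite/granite-3.3-2b-instruct-sft-3",  # f"{model}-sft-{n_epochs}"
    created_at=datetime.now(timezone.utc),
    epoch=3,
    post_training_job_id="job-1234",
    path="./checkpoints/merged_model",
)
```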

Signed-off-by: Charlie Doern <cdoern@redhat.com>
@@ -0,0 +1,502 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import gc
import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import psutil

# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from trl import SFTConfig, SFTTrainer

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    Checkpoint,
    DataConfig,
    LoraFinetuningConfig,
    TrainingConfig,
)

from ..config import HuggingFacePostTrainingConfig

logger = logging.getLogger(__name__)

def get_gb(to_convert: int) -> str:
    """Converts memory stats to GB and formats to 2 decimal places.

    Args:
        to_convert: Memory value in bytes

    Returns:
        str: Memory value in GB formatted to 2 decimal places
    """
    return f"{(to_convert / (1024**3)):.2f}"

def get_memory_stats(device: torch.device) -> dict[str, Any]:
    """Get memory statistics for the given device."""
    stats = {
        "system_memory": {
            "total": get_gb(psutil.virtual_memory().total),
            "available": get_gb(psutil.virtual_memory().available),
            "used": get_gb(psutil.virtual_memory().used),
            "percent": psutil.virtual_memory().percent,
        }
    }

    if device.type == "cuda":
        stats["device_memory"] = {
            "allocated": get_gb(torch.cuda.memory_allocated(device)),
            "reserved": get_gb(torch.cuda.memory_reserved(device)),
            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
        }
    elif device.type == "mps":
        # MPS doesn't provide direct memory stats, but we can track system memory
        stats["device_memory"] = {
            "note": "MPS memory stats not directly available",
            "system_memory_used": get_gb(psutil.virtual_memory().used),
        }
    elif device.type == "cpu":
        # For CPU, we track process memory usage
        process = psutil.Process()
        stats["device_memory"] = {
            "process_rss": get_gb(process.memory_info().rss),
            "process_vms": get_gb(process.memory_info().vms),
            "process_percent": process.memory_percent(),
        }

    return stats

class HFFinetuningSingleDevice:
    def __init__(
        self,
        job_uuid,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
    ):
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
        self.job_uuid = job_uuid

    def validate_dataset_format(self, rows: list[dict]) -> bool:
        """Validate that the dataset has the required fields."""
        required_fields = ["input_query", "expected_answer", "chat_completion_input"]
        return all(field in row for row in rows for field in required_fields)

    def _process_instruct_format(self, row: dict) -> tuple[str | None, str | None]:
        """Process a row in instruct format."""
        if "chat_completion_input" in row and "expected_answer" in row:
            try:
                messages = json.loads(row["chat_completion_input"])
                if not isinstance(messages, list) or len(messages) != 1:
                    logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
                    return None, None
                if "content" not in messages[0]:
                    logger.warning(f"Message missing content: {messages[0]}")
                    return None, None
                return messages[0]["content"], row["expected_answer"]
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
                return None, None
        return None, None

    def _process_dialog_format(self, row: dict) -> tuple[str | None, str | None]:
        """Process a row in dialog format."""
        if "dialog" in row:
            try:
                dialog = json.loads(row["dialog"])
                if not isinstance(dialog, list) or len(dialog) < 2:
                    logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
                    return None, None
                if dialog[0].get("role") != "user":
                    logger.warning(f"First message must be from user: {dialog[0]}")
                    return None, None
                if not any(msg.get("role") == "assistant" for msg in dialog):
                    logger.warning("Dialog must have at least one assistant message")
                    return None, None

                # Convert to human/gpt format
                role_map = {"user": "human", "assistant": "gpt"}
                conversations = []
                for msg in dialog:
                    if "role" not in msg or "content" not in msg:
                        logger.warning(f"Message missing role or content: {msg}")
                        continue
                    conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})

                # Format as a single conversation
                return conversations[0]["value"], conversations[1]["value"]
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse dialog: {row['dialog']}")
                return None, None
        return None, None

    def _process_fallback_format(self, row: dict) -> tuple[str | None, str | None]:
        """Process a row using fallback formats."""
        if "input" in row and "output" in row:
            return row["input"], row["output"]
        elif "prompt" in row and "completion" in row:
            return row["prompt"], row["completion"]
        elif "question" in row and "answer" in row:
            return row["question"], row["answer"]
        return None, None

    def _format_text(self, input_text: str, output_text: str, provider_config: HuggingFacePostTrainingConfig) -> str:
        """Format input and output text based on model requirements."""
        if hasattr(provider_config, "chat_template"):
            return provider_config.chat_template.format(input=input_text, output=output_text)
        return f"{input_text}\n{output_text}"

    def _create_dataset(
        self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
    ) -> Dataset:
        """Create and preprocess the dataset."""
        formatted_rows = []
        for row in rows:
            input_text = None
            output_text = None

            # Process based on format
            assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
            if config.data_config.data_format.value == "instruct":
                input_text, output_text = self._process_instruct_format(row)
            elif config.data_config.data_format.value == "dialog":
                input_text, output_text = self._process_dialog_format(row)
            else:
                input_text, output_text = self._process_fallback_format(row)

            if input_text and output_text:
                formatted_text = self._format_text(input_text, output_text, provider_config)
                formatted_rows.append({"text": formatted_text})

        if not formatted_rows:
            assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
            raise ValueError(
                f"No valid input/output pairs found in the dataset for format: {config.data_config.data_format.value}"
            )

        return Dataset.from_list(formatted_rows)

    def _preprocess_dataset(
        self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
    ) -> Dataset:
        """Preprocess the dataset with tokenizer."""

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                padding=True,
                truncation=True,
                max_length=provider_config.max_seq_length,
                return_tensors=None,
            )

        return ds.map(
            tokenize_function,
            batched=True,
            remove_columns=ds.column_names,
        )

    async def load_dataset(
        self,
        model: str,
        config: TrainingConfig,
        provider_config: HuggingFacePostTrainingConfig,
    ) -> tuple[Dataset, Dataset, AutoTokenizer]:
        """Load and preprocess the dataset for training.

        Args:
            model: The model identifier to load
            config: Training configuration containing dataset settings
            provider_config: Provider-specific configuration

        Returns:
            tuple containing:
                - Training dataset
                - Evaluation dataset
                - Tokenizer

        Raises:
            ValueError: If dataset is missing required fields
            RuntimeError: If tokenizer initialization fails
        """
        assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"

        rows = await self._setup_data(config.data_config.dataset_id)

        # Validate that the dataset has the required fields for training
        if not self.validate_dataset_format(rows):
            raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")

        # Initialize tokenizer with model-specific config
        try:
            tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)

            # Set up tokenizer defaults
            if not tokenizer.pad_token:
                tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "right"
            tokenizer.truncation_side = "right"
            tokenizer.model_max_length = provider_config.max_seq_length
        except Exception as e:
            raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e

        # Create and preprocess dataset
        try:
            ds = self._create_dataset(rows, config, provider_config)
            ds = self._preprocess_dataset(ds, tokenizer, provider_config)
        except Exception as e:
            raise ValueError(f"Failed to create dataset: {str(e)}") from e

        # Split dataset into train and validation
        train_val_split = ds.train_test_split(test_size=0.1, seed=42)

        return train_val_split["train"], train_val_split["test"], tokenizer

    def load_model(
        self,
        model: str,
        device: torch.device,
        provider_config: HuggingFacePostTrainingConfig,
    ) -> AutoModelForCausalLM:
        """Load and initialize the model for training.

        Args:
            model: The model identifier to load
            device: The device to load the model onto
            provider_config: Provider-specific configuration

        Returns:
            The loaded and initialized model

        Raises:
            RuntimeError: If model loading fails
        """
        logger.info("Loading the base model")
        try:
            model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
            model_obj = AutoModelForCausalLM.from_pretrained(
                model,
                torch_dtype="auto",
                quantization_config=None,
                config=model_config,
                **provider_config.model_specific_config,
            )
            if model_obj.device != device:
                model_obj = model_obj.to(device)
            logger.info(f"Model loaded and moved to device: {model_obj.device}")
            return model_obj
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    async def train(
        self,
        model: str,
        output_dir: str | None,
        job_uuid: str,
        lora_config: LoraFinetuningConfig,
        config: TrainingConfig,
        provider_config: HuggingFacePostTrainingConfig,
    ) -> tuple[dict[str, Any], list[Checkpoint] | None]:
        """Train a model using HuggingFace's SFTTrainer"""
        try:
            device = torch.device(provider_config.device)
        except RuntimeError as e:
            raise RuntimeError(f"Error getting Torch Device {str(e)}") from e

        # Detect device type and validate
        if device.type == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError(
                    f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
                )
            # map unqualified 'cuda' to current device
            if device.index is None:
                device = torch.device(device.type, torch.cuda.current_device())
        elif device.type == "mps":
            if not torch.backends.mps.is_available():
                raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
        elif device.type == "hpu":
            raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")

        logger.info(f"Using device '{device}'")

        output_dir_path = None
        if output_dir:
            output_dir_path = Path(output_dir)

        # Track memory stats throughout training
        memory_stats = {
            "initial": get_memory_stats(device),
            "after_model_load": None,
            "after_training": None,
            "final": None,
        }

        # Validate data config
        if not config.data_config:
            raise ValueError("DataConfig is required for training")

        # Load dataset and tokenizer
        train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config, provider_config)

        # Load model with model-specific config
        model_obj = self.load_model(model, device, provider_config)
        memory_stats["after_model_load"] = get_memory_stats(device)

        # Configure LoRA
        peft_config = None
        if lora_config:
            peft_config = LoraConfig(
                lora_alpha=lora_config.alpha,
                lora_dropout=0.1,
                r=lora_config.rank,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=lora_config.lora_attn_modules,
            )

        # Setup training arguments
        lr = 2e-5
        if config.optimizer_config:
            lr = config.optimizer_config.lr

        # Calculate steps per epoch and appropriate intervals
        steps_per_epoch = len(train_dataset) // config.data_config.batch_size
        eval_steps = max(1, steps_per_epoch // 10)  # Evaluate 10 times per epoch
        save_steps = max(1, steps_per_epoch // 5)  # Save 5 times per epoch
        logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch

        logger.info(f"Dataset size: {len(train_dataset)} examples")
        logger.info(f"Batch size: {config.data_config.batch_size}")
        logger.info(f"Steps per epoch: {steps_per_epoch}")
        logger.info(f"Will evaluate every {eval_steps} steps")
        logger.info(f"Will save every {save_steps} steps")
        logger.info(f"Will log every {logging_steps} steps")

        # save_strategy should be none if output dir is none
        save_strategy = "no"
        if output_dir_path:
            save_strategy = "steps"

        training_arguments = SFTConfig(
            max_steps=config.max_steps_per_epoch,
            output_dir=str(output_dir_path) if output_dir_path is not None else None,
            num_train_epochs=config.n_epochs,
            per_device_train_batch_size=config.data_config.batch_size,
            fp16=device.type == "cuda",
            bf16=device.type != "cuda",
            # use_cpu should only be set if we are on a "True" CPU machine, not a MPS enabled Mac due to stability issues.
            use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
            save_strategy=save_strategy,
            save_steps=save_steps,
            report_to="none",
            max_seq_length=provider_config.max_seq_length,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            gradient_checkpointing=provider_config.gradient_checkpointing,
            learning_rate=lr,
            warmup_ratio=provider_config.warmup_ratio,
            weight_decay=provider_config.weight_decay,
            logging_steps=logging_steps,
            # Enable validation
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_total_limit=provider_config.save_total_limit,
            remove_unused_columns=False,
            dataloader_pin_memory=provider_config.dataloader_pin_memory,
            dataloader_num_workers=provider_config.dataloader_num_workers,
            dataset_text_field="text",
            packing=False,
            # Add evaluation metrics
            # loading the best model can only happen if we have saved a model
            load_best_model_at_end=True if output_dir_path else False,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
        )

        # Initialize trainer with both train and eval datasets
        trainer = SFTTrainer(
            model=model_obj,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            peft_config=peft_config,
            args=training_arguments,
        )

        # Train
        logger.info("Starting training")
        try:
            trainer.train()
            memory_stats["after_training"] = get_memory_stats(device)

            # Save final model
            model_obj.config.use_cache = True

            # if we have LoRA we need to do `merge_and_unload`
            if lora_config:
                model_obj = trainer.model.merge_and_unload()
            else:
                model_obj = trainer.model

            checkpoint = None
            checkpoints = None

            # only save a final model if checkpoint dir is specified
            # this is especially useful to test training rather than saving of checkpoints
            if output_dir_path:
                model_obj.save_pretrained(output_dir_path / "merged_model")

                # Create checkpoint
                checkpoint = Checkpoint(
                    identifier=f"{model}-sft-{config.n_epochs}",
                    created_at=datetime.now(timezone.utc),
                    epoch=config.n_epochs,
                    post_training_job_id=job_uuid,
                    path=str(output_dir_path / "merged_model"),
                )
                checkpoints = [checkpoint]

            return memory_stats, checkpoints
        finally:
            # Clean up resources
            if hasattr(trainer, "model"):
                if device.type != "cpu":
                    trainer.model.to("cpu")
                    if device.type == "cuda":
                        torch.cuda.empty_cache()
                del trainer.model
            del trainer
            gc.collect()
            memory_stats["final"] = get_memory_stats(device)

    async def _setup_data(
        self,
        dataset_id: str,
    ) -> list[dict[str, Any]]:
        """Load dataset from llama stack dataset provider"""
        try:

            async def fetch_rows(dataset_id: str):
                return await self.datasetio_api.iterrows(
                    dataset_id=dataset_id,
                    limit=-1,
                )

            all_rows = await fetch_rows(dataset_id)
            if not isinstance(all_rows.data, list):
                raise RuntimeError("Expected dataset data to be a list")
            return all_rows.data
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e