Nehanth Narendrula, 2025-07-24 20:55:49 -07:00 (committed by GitHub)
commit afe58dd244
7 changed files with 930 additions and 207 deletions

View file

@@ -24,6 +24,9 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
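The three fields above are the new DPO-specific knobs. As a rough sketch (assuming the remaining provider fields keep the defaults listed in the table), they could be set when constructing the provider config directly:

```python
# Illustrative only: field names follow the table above; values are examples.
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

config = HuggingFacePostTrainingConfig(
    checkpoint_format="huggingface",
    distributed_backend=None,
    device="cpu",
    dpo_beta=0.1,               # how strongly the policy is kept close to the reference model
    use_reference_model=True,   # load a separate frozen reference model for DPO
    dpo_loss_type="sigmoid",    # one of: sigmoid, hinge, ipo, kto_pair
)
```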
## Sample Configuration

View file

@@ -67,6 +67,11 @@ class HuggingFacePostTrainingConfig(BaseModel):
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
# DPO-specific parameters
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}

View file

@@ -25,6 +25,9 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
@@ -36,6 +39,7 @@ class TrainingArtifactType(Enum):
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
_JOB_TYPE_DPO_TRAINING = "dpo-training"
class HuggingFacePostTrainingImpl:
@@ -119,12 +123,40 @@ class HuggingFacePostTrainingImpl:
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError("DPO alignment is not implemented yet")
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF DPO alignment")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
recipe = HFDPOAlignmentSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
# Use default checkpoint directory
output_dir = f"./checkpoints/dpo/{job_uuid}"
resources_allocated, checkpoints = await recipe.train(
model=finetuned_model,
output_dir=output_dir,
job_uuid=job_uuid,
dpo_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
else:
on_log_message_cb("Warning: No checkpoints were saved during DPO training")
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF DPO alignment completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
@@ -174,3 +206,9 @@ class HuggingFacePostTrainingImpl:
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)

View file

@@ -9,14 +9,9 @@ import json
import logging
import multiprocessing
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import psutil
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
@@ -31,7 +26,6 @@ import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
@@ -47,91 +41,22 @@ from llama_stack.apis.post_training import (
)
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__)
def get_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
class HFFinetuningSingleDevice:
def __init__(
self,
@ -262,19 +187,6 @@ class HFFinetuningSingleDevice:
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
@ -327,7 +239,7 @@ class HFFinetuningSingleDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
@ -369,47 +281,10 @@ class HFFinetuningSingleDevice:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args(
self,
config: TrainingConfig,
@ -439,27 +314,12 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
eval_strategy = "no"
if output_dir_path:
save_strategy = "epoch"
eval_strategy = "epoch"
logger.info(f"Will save checkpoints to {output_dir_path}")
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
return SFTConfig(
max_steps=max_steps,
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
@ -483,7 +343,7 @@ class HFFinetuningSingleDevice:
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=logging_steps,
logging_steps=step_info["logging_steps"],
)
def save_model(
@ -523,13 +383,8 @@ class HFFinetuningSingleDevice:
) -> None:
"""Run the training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
@ -558,7 +413,7 @@ class HFFinetuningSingleDevice:
)
# Load model
model_obj = self.load_model(model, device, provider_config_obj)
model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
@ -633,9 +488,8 @@ class HFFinetuningSingleDevice:
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
# Setup multiprocessing for device
setup_multiprocessing_for_device(device)
process = multiprocessing.Process(
target=self._run_training_sync,
@ -663,37 +517,7 @@ class HFFinetuningSingleDevice:
checkpoints = []
if output_dir_path:
# Get all checkpoint directories and sort them numerically
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
# Add all checkpoint directories
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
# Get the creation time of the directory
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add the merged model as a checkpoint
merged_model_path = output_dir_path / "merged_model"
if merged_model_path.exists():
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(merged_model_path),
)
checkpoints.append(checkpoint)
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
return memory_stats, checkpoints if checkpoints else None
finally:

View file

@@ -0,0 +1,493 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import logging
import multiprocessing
import os
from pathlib import Path
from typing import Any
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from transformers import (
AutoTokenizer,
)
from trl import DPOConfig, DPOTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DPOAlignmentConfig,
TrainingConfig,
)
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__)
class HFDPOAlignmentSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
"""Validate that the dataset has the required fields for DPO training."""
required_fields = ["prompt", "chosen", "rejected"]
if not rows:
logger.warning("Dataset is empty")
return False
for i, row in enumerate(rows):
if not isinstance(row, dict):
logger.warning(f"Row {i} is not a dictionary")
return False
for field in required_fields:
if field not in row:
logger.warning(f"Row {i} missing required DPO field: {field}")
return False
# Handle both string and list formats
if field == "prompt":
# Prompt should be a string
if not isinstance(row[field], str):
logger.warning(f"Row {i} field '{field}' is not a string")
return False
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
return False
else:
# chosen/rejected can be either strings or lists of messages
if isinstance(row[field], str):
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
return False
elif isinstance(row[field], list):
if not row[field]:
logger.warning(f"Row {i} field '{field}' is empty list")
return False
else:
logger.warning(f"Row {i} field '{field}' is neither string nor list")
return False
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
return True
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
"""Process a row in DPO format, handling both string and conversation list formats."""
if all(field in row for field in ["prompt", "chosen", "rejected"]):
prompt = row["prompt"]
# Handle chosen field - convert list to string if needed
if isinstance(row["chosen"], list):
# For conversation format, concatenate messages
chosen = "\n".join(
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["chosen"]]
)
else:
chosen = row["chosen"]
# Handle rejected field - convert list to string if needed
if isinstance(row["rejected"], list):
# For conversation format, concatenate messages
rejected = "\n".join(
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["rejected"]]
)
else:
rejected = row["rejected"]
return prompt, chosen, rejected
return None, None, None
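# Illustrative behavior: when "chosen"/"rejected" are lists of chat messages, their
# "content" values are joined with newlines, e.g.
#   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
# becomes "Hi\nHello!"; plain strings are returned unchanged.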
def _format_text_for_dpo(self, prompt: str, response: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format prompt and response text based on model requirements."""
if hasattr(provider_config, "chat_template") and provider_config.chat_template:
# Use the chat template, supporting both {prompt}/{response} and {input}/{output}
template = provider_config.chat_template
# Try prompt/response first (DPO style)
if "{prompt}" in template and "{response}" in template:
return template.format(prompt=prompt, response=response)
# Fall back to input/output (SFT style)
elif "{input}" in template and "{output}" in template:
return template.format(input=prompt, output=response)
else:
# If template doesn't have expected placeholders, use default
return f"{prompt}\n{response}"
return f"{prompt}\n{response}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset for DPO."""
dpo_examples = []
for row in rows:
prompt, chosen, rejected = self._process_dpo_format(row)
if prompt and chosen and rejected:
# Format the texts
chosen_formatted = self._format_text_for_dpo(prompt, chosen, provider_config)
rejected_formatted = self._format_text_for_dpo(prompt, rejected, provider_config)
dpo_examples.append(
{
"prompt": prompt,
"chosen": chosen_formatted,
"rejected": rejected_formatted,
}
)
if not dpo_examples:
raise ValueError("No valid preference examples found in dataset")
logger.info(f"Created DPO dataset with {len(dpo_examples)} preference pairs")
return Dataset.from_list(dpo_examples)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer for DPO."""
# DPOTrainer expects raw text, so we don't tokenize here
# Just return the dataset as is
return ds
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
dpo_config: dict[str, Any],
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running DPO training process."""
import asyncio
logger.info("Starting DPO training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
dpo_config=dpo_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for DPO training."""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for DPO training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to left for DPO
tokenizer.padding_side = "left"
# Set truncation side to right to keep the beginning of the sequence
tokenizer.truncation_side = "right"
# Set model max length to match provider config
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully for DPO")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset for DPO")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
dpo_config: DPOAlignmentConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> DPOConfig:
"""Setup DPO training arguments."""
logger.info("Configuring DPO training arguments")
lr = 5e-7 # Lower learning rate for DPO
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
logger.info("DPO training configuration:")
logger.info(f"- DPO beta: {dpo_config.beta}")
logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
# Calculate max prompt length as half of max sequence length
max_prompt_length = provider_config.max_seq_length // 2
return DPOConfig(
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy=eval_strategy,
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_length=provider_config.max_seq_length,
max_prompt_length=max_prompt_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=step_info["logging_steps"],
save_total_limit=provider_config.save_total_limit,
# DPO specific parameters
beta=dpo_config.beta,
loss_type=provider_config.dpo_loss_type,
)
def save_model(
self,
trainer: DPOTrainer,
output_dir_path: Path,
) -> None:
"""Save the trained DPO model."""
logger.info("Saving final DPO model")
save_path = output_dir_path / "dpo_model"
logger.info(f"Saving model to {save_path}")
# Save model and tokenizer
trainer.save_model(str(save_path))
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
dpo_config: dict[str, Any],
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the DPO training process with signal handling."""
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
dpo_config_obj = DPOAlignmentConfig(**dpo_config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
dpo_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model and reference model
model_obj = load_model(model, device, provider_config_obj)
ref_model = None
if provider_config_obj.use_reference_model:
logger.info("Loading separate reference model for DPO")
ref_model = load_model(model, device, provider_config_obj)
else:
logger.info("Using shared reference model for DPO")
# Initialize DPO trainer
logger.info("Initializing DPOTrainer")
trainer = DPOTrainer(
model=model_obj,
ref_model=ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
)
try:
# Train
logger.info("Starting DPO training")
trainer.train()
logger.info("DPO training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
logger.info(f"Saving model to output directory: {output_dir_path}")
self.save_model(trainer, output_dir_path)
logger.info("Model save completed")
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
if ref_model:
evacuate_model_from_device(ref_model, device.type)
del trainer
del ref_model
gc.collect()
logger.info("Cleanup completed")
logger.info("DPO training process finishing successfully")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
dpo_config: DPOAlignmentConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's DPOTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting DPO training in separate process")
try:
# Setup multiprocessing for device
setup_multiprocessing_for_device(device)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"dpo_config": dpo_config.model_dump(),
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"DPO training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = []
if output_dir_path:
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
return memory_stats, checkpoints if checkpoints else None
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()

View file

@@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import multiprocessing
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import psutil
import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
def setup_environment():
"""Setup common environment variables for training."""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
def get_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def load_model(
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
"""Split dataset into train and validation sets.
Args:
ds: Dataset to split
Returns:
tuple: (train_dataset, eval_dataset)
"""
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset
def setup_signal_handlers():
"""Setup signal handlers for graceful shutdown."""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
"""Calculate training steps and logging configuration.
Args:
steps_per_epoch: Number of training steps per epoch
config: Training configuration
Returns:
dict: Dictionary with calculated step values
"""
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
"""Get save and evaluation strategy based on output directory.
Args:
output_dir_path: Optional path to save the model
Returns:
tuple: (save_strategy, eval_strategy)
"""
if output_dir_path:
logger.info(f"Will save checkpoints to {output_dir_path}")
return "epoch", "epoch"
return "no", "no"
def create_checkpoints(
output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
"""Create checkpoint objects from training output.
Args:
output_dir_path: Path to the training output directory
job_uuid: Unique identifier for the training job
model: Model identifier
config: Training configuration
final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
Returns:
List of Checkpoint objects
"""
checkpoints = []
# Add checkpoint directories
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add final model
final_model_path = output_dir_path / final_model_name
if final_model_path.exists():
training_type = "sft" if final_model_name == "merged_model" else "dpo"
checkpoint = Checkpoint(
identifier=f"{model}-{training_type}-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(final_model_path),
)
checkpoints.append(checkpoint)
return checkpoints
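# Illustrative result: intermediate checkpoints keep their directory names ("checkpoint-100", ...),
# while the final model checkpoint is identified as f"{model}-{training_type}-{n_epochs}",
# e.g. "distilgpt2-dpo-1" when final_model_name="dpo_model", model="distilgpt2", n_epochs=1.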
def setup_multiprocessing_for_device(device: torch.device):
"""Setup multiprocessing start method based on device type.
Args:
device: The device being used for training
"""
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)

View file

@@ -13,6 +13,9 @@ import pytest
from llama_stack.apis.post_training import (
DataConfig,
DatasetFormat,
DPOAlignmentConfig,
DPOLossType,
LoraFinetuningConfig,
TrainingConfig,
)
@@ -43,6 +46,7 @@ sys.stdout.reconfigure(line_buffering=True)
# -v -s --tb=short --disable-warnings
# SFT test
class TestPostTraining:
@pytest.mark.integration
@pytest.mark.parametrize(
@ -81,7 +85,7 @@ class TestPostTraining:
dataset_id=dataset.identifier,
batch_size=1,
shuffle=False,
data_format="instruct",
data_format=DatasetFormat.instruct,
)
# setup training config with minimal settings
@ -122,6 +126,8 @@ class TestPostTraining:
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"Job artifacts: {artifacts}")
logger.info(f"Registered dataset with ID: {dataset.identifier}")
# TODO: Fix these tests to properly represent the Jobs API in training
#
# async def test_get_training_jobs(self, post_training_stack):
@ -149,3 +155,78 @@ class TestPostTraining:
# assert job_artifacts.checkpoints[0].identifier == "instructlab/granite-7b-lab"
# assert job_artifacts.checkpoints[0].epoch == 0
# assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
# DPO test
@pytest.mark.integration
@pytest.mark.parametrize(
"purpose, source",
[
(
"post-training/messages",
{
"type": "uri",
"uri": "huggingface://datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style?split=train[:20]",
},
),
],
)
@pytest.mark.timeout(360)
def test_preference_optimize(self, llama_stack_client, purpose, source):
logger.info("Starting DPO preference optimization test")
# register preference dataset to train
dataset = llama_stack_client.datasets.register(
purpose=purpose,
source=source,
)
logger.info(f"Registered preference dataset with ID: {dataset.identifier}")
# DPO algorithm configuration
algorithm_config = DPOAlignmentConfig(
beta=0.1,
loss_type=DPOLossType.sigmoid,
)
data_config = DataConfig(
dataset_id=dataset.identifier,
batch_size=1,
shuffle=False,
data_format=DatasetFormat.dialog, # DPO datasets often use dialog format
)
# setup training config with minimal settings for DPO
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
max_steps_per_epoch=1, # just one step for quick testing
gradient_accumulation_steps=1,
)
job_uuid = f"test-dpo-job-{uuid.uuid4()}"
logger.info(f"Starting DPO training job with UUID: {job_uuid}")
# train with HuggingFace DPO implementation
_ = llama_stack_client.post_training.preference_optimize(
job_uuid=job_uuid,
finetuned_model="distilgpt2", # Much smaller model for faster CI testing
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config={},
logger_config={},
)
while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
if not status:
logger.error("DPO job not found")
break
logger.info(f"Current DPO status: {status}")
if status.status == "completed":
break
logger.info("Waiting for DPO job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"DPO job artifacts: {artifacts}")