Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-13 21:29:57 +00:00
feat: Enable DPO training with HuggingFace inline provider (#2825)
Some checks failed
Integration Tests / discover-tests (push) Has been skipped
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 7s
Integration Tests / record-tests (push) Has been skipped
Integration Tests / run-tests (push) Has been skipped
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 22s
Python Package Build Test / build (3.13) (push) Failing after 16s
Test Llama Stack Build / generate-matrix (push) Successful in 19s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 31s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 32s
Test External API and Providers / test-external (venv) (push) Failing after 32s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 36s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 39s
Update ReadTheDocs / update-readthedocs (push) Failing after 31s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 42s
Test Llama Stack Build / build-single-provider (push) Failing after 37s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Failing after 35s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 37s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 40s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 42s
Unit Tests / unit-tests (3.12) (push) Failing after 36s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 40s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 45s
Test Llama Stack Build / build (push) Failing after 6s
Python Package Build Test / build (3.12) (push) Failing after 1m1s
Unit Tests / unit-tests (3.13) (push) Failing after 1m0s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 1m6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 1m8s
Pre-commit / pre-commit (push) Successful in 1m50s
What does this PR do?
This PR adds support for Direct Preference Optimization (DPO) training via the existing HuggingFace inline provider. It introduces a new DPO training recipe, config schema updates, dataset integration, and end-to-end testing to support preference-based fine-tuning with TRL.

Test Plan
Added integration test: tests/integration/post_training/test_post_training.py::TestPostTraining::test_preference_optimize
Ran tests on both CPU and CUDA environments.

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-43-83.ec2.internal>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
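The new recipe performs preference fine-tuning through TRL. For readers unfamiliar with DPO, the following is a minimal, hypothetical sketch of the kind of TRL call such a recipe builds on; it is not the provider's actual code. The model id, the single preference row, and the hyperparameters are placeholders, and depending on the installed TRL release the tokenizer is passed as processing_class (newer) or tokenizer (older).

# Minimal, illustrative DPO run with TRL; model, data, and hyperparameters are placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_name = "HuggingFaceTB/SmolLM2-135M"  # assumption: any small causal LM works for a smoke test
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# DPO consumes preference triples: a prompt plus a preferred ("chosen")
# and a dispreferred ("rejected") completion.
train_dataset = Dataset.from_list(
    [
        {
            "prompt": "What is the capital of France?",
            "chosen": "The capital of France is Paris.",
            "rejected": "I have no idea.",
        }
    ]
)

training_args = DPOConfig(
    output_dir="./dpo_smoke_test",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    beta=0.1,  # strength of the preference penalty against the reference model
)

trainer = DPOTrainer(
    model=model,                 # with no explicit ref_model, TRL derives a frozen reference copy
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # older TRL releases name this argument `tokenizer`
)
trainer.train()

In the stack, the same flow runs behind the post-training preference-optimization endpoint, with the preference dataset served through the DatasetIO API rather than built inline.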
This commit is contained in:
parent
2665f00102
commit
cf73146132
7 changed files with 913 additions and 215 deletions
269 llama_stack/providers/inline/post_training/huggingface/utils.py Normal file
@@ -0,0 +1,269 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

import psutil
import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig

from .config import HuggingFacePostTrainingConfig

logger = logging.getLogger(__name__)


def setup_environment():
    """Setup common environment variables for training."""
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["MKL_THREADING_LAYER"] = "GNU"
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
    os.environ["MKL_NUM_THREADS"] = "1"


def bytes_to_gb(to_convert: int) -> str:
    """Converts memory stats to GB and formats to 2 decimal places.

    Args:
        to_convert: Memory value in bytes

    Returns:
        str: Memory value in GB formatted to 2 decimal places
    """
    return f"{(to_convert / (1024**3)):.2f}"


def get_memory_stats(device: torch.device) -> dict[str, Any]:
    """Get memory statistics for the given device."""
    stats = {
        "system_memory": {
            "total": bytes_to_gb(psutil.virtual_memory().total),
            "available": bytes_to_gb(psutil.virtual_memory().available),
            "used": bytes_to_gb(psutil.virtual_memory().used),
            "percent": psutil.virtual_memory().percent,
        }
    }

    if device.type == "cuda":
        stats["device_memory"] = {
            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
        }
    elif device.type == "mps":
        # MPS doesn't provide direct memory stats, but we can track system memory
        stats["device_memory"] = {
            "note": "MPS memory stats not directly available",
            "system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
        }
    elif device.type == "cpu":
        # For CPU, we track process memory usage
        process = psutil.Process()
        stats["device_memory"] = {
            "process_rss": bytes_to_gb(process.memory_info().rss),
            "process_vms": bytes_to_gb(process.memory_info().vms),
            "process_percent": process.memory_percent(),
        }

    return stats


def setup_torch_device(device_str: str) -> torch.device:
    """Initialize and validate a PyTorch device.

    This function handles device initialization and validation for different device types:
    - CUDA: Validates CUDA availability and handles device selection
    - MPS: Validates MPS availability for Apple Silicon
    - CPU: Basic validation
    - HPU: Raises error as it's not supported

    Args:
        device_str: String specifying the device ('cuda', 'cpu', 'mps')

    Returns:
        torch.device: The initialized and validated device

    Raises:
        RuntimeError: If device initialization fails or device is not supported
    """
    try:
        device = torch.device(device_str)
    except RuntimeError as e:
        raise RuntimeError(f"Error getting Torch Device {str(e)}") from e

    # Validate device capabilities
    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
            )
        if device.index is None:
            device = torch.device(device.type, torch.cuda.current_device())
    elif device.type == "mps":
        if not torch.backends.mps.is_available():
            raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
    elif device.type == "hpu":
        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")

    return device


async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
    """Load dataset from llama stack dataset provider"""
    try:
        all_rows = await datasetio_api.iterrows(
            dataset_id=dataset_id,
            limit=-1,
        )
        if not isinstance(all_rows.data, list):
            raise RuntimeError("Expected dataset data to be a list")
        return all_rows.data
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset: {str(e)}") from e


def load_model(
    model: str,
    device: torch.device,
    provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
    """Load and initialize the model for training.

    Args:
        model: The model identifier to load
        device: The device to load the model onto
        provider_config: Provider-specific configuration

    Returns:
        The loaded and initialized model

    Raises:
        RuntimeError: If model loading fails
    """
    logger.info("Loading the base model")
    try:
        model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
        model_obj = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype="auto" if device.type != "cpu" else "float32",
            quantization_config=None,
            config=model_config,
            **provider_config.model_specific_config,
        )
        # Always move model to specified device
        model_obj = model_obj.to(device)
        logger.info(f"Model loaded and moved to device: {model_obj.device}")
        return model_obj
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}") from e


def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
    """Split dataset into train and validation sets.

    Args:
        ds: Dataset to split

    Returns:
        tuple: (train_dataset, eval_dataset)
    """
    logger.info("Splitting dataset into train and validation sets")
    train_val_split = ds.train_test_split(test_size=0.1, seed=42)
    train_dataset = train_val_split["train"]
    eval_dataset = train_val_split["test"]
    logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
    return train_dataset, eval_dataset


def setup_signal_handlers():
    """Setup signal handlers for graceful shutdown."""

    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, initiating graceful shutdown")
        sys.exit(0)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)


def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
    """Calculate training steps and logging configuration.

    Args:
        steps_per_epoch: Number of training steps per epoch
        config: Training configuration

    Returns:
        dict: Dictionary with calculated step values
    """
    total_steps = steps_per_epoch * config.n_epochs
    max_steps = min(config.max_steps_per_epoch, total_steps)
    logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch

    logger.info("Training configuration:")
    logger.info(f"- Steps per epoch: {steps_per_epoch}")
    logger.info(f"- Total steps: {total_steps}")
    logger.info(f"- Max steps: {max_steps}")
    logger.info(f"- Logging steps: {logging_steps}")

    return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}


def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
    """Get save and evaluation strategy based on output directory.

    Args:
        output_dir_path: Optional path to save the model

    Returns:
        tuple: (save_strategy, eval_strategy)
    """
    if output_dir_path:
        logger.info(f"Will save checkpoints to {output_dir_path}")
        return "epoch", "epoch"
    return "no", "no"


def create_checkpoints(
    output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
    """Create checkpoint objects from training output.

    Args:
        output_dir_path: Path to the training output directory
        job_uuid: Unique identifier for the training job
        model: Model identifier
        config: Training configuration
        final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)

    Returns:
        List of Checkpoint objects
    """
    checkpoints = []

    # Add checkpoint directories
    checkpoint_dirs = sorted(
        [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
        key=lambda x: int(x.name.split("-")[1]),
    )

    for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
        created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
        checkpoint = Checkpoint(
            identifier=checkpoint_dir.name,
            created_at=created_time,
            epoch=epoch_number,
            post_training_job_id=job_uuid,
            path=str(checkpoint_dir),
        )
        checkpoints.append(checkpoint)

    # Add final model
    final_model_path = output_dir_path / final_model_name
    if final_model_path.exists():
        training_type = "sft" if final_model_name == "merged_model" else "dpo"
        checkpoint = Checkpoint(
            identifier=f"{model}-{training_type}-{config.n_epochs}",
            created_at=datetime.now(UTC),
            epoch=config.n_epochs,
            post_training_job_id=job_uuid,
            path=str(final_model_path),
        )
        checkpoints.append(checkpoint)

    return checkpoints
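As a rough illustration of how the helpers in this file compose, the hypothetical driver below strings them together the way a DPO recipe might. It is a sketch, not the provider's recipe: the dataset id, device string, batch size, and base model are made up, and datasetio_api, provider_config, and training_config stand in for the DatasetIO, HuggingFacePostTrainingConfig, and TrainingConfig objects the stack injects.

# Hypothetical driver showing how the helpers above fit together; the dataset id,
# device string, batch size, and base model are illustrative, not the provider's recipe.
from datasets import Dataset

from llama_stack.providers.inline.post_training.huggingface.utils import (
    calculate_training_steps,
    load_model,
    load_rows_from_dataset,
    setup_environment,
    setup_signal_handlers,
    setup_torch_device,
    split_dataset,
)


async def prepare_preference_training(datasetio_api, provider_config, training_config):
    setup_environment()      # quiet tokenizer parallelism / MKL threading warnings
    setup_signal_handlers()  # exit cleanly on SIGTERM / SIGINT
    device = setup_torch_device("cuda")  # or "cpu" / "mps"

    # Pull preference rows from the stack's dataset provider and split them 90/10.
    rows = await load_rows_from_dataset(datasetio_api, dataset_id="my_preference_dataset")
    train_ds, eval_ds = split_dataset(Dataset.from_list(rows))

    # Load the base policy model onto the chosen device.
    model = load_model("meta-llama/Llama-3.2-1B-Instruct", device, provider_config)

    # Derive step counts for the trainer from the stack-level TrainingConfig.
    steps_per_epoch = max(1, len(train_ds) // 8)  # assumption: effective batch size of 8
    step_info = calculate_training_steps(steps_per_epoch, training_config)

    # ...hand model, train_ds, eval_ds, and step_info to TRL's DPOTrainer here...
    return model, train_ds, eval_ds, step_info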