removed more redunant code

This commit is contained in:
Ubuntu 2025-07-23 18:05:29 +00:00
parent 41f4678faf
commit 736404c1bd
3 changed files with 215 additions and 202 deletions

View file

@ -9,9 +9,6 @@ import json
import logging import logging
import multiprocessing import multiprocessing
import os import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -29,7 +26,6 @@ import torch
from datasets import Dataset from datasets import Dataset
from peft import LoraConfig from peft import LoraConfig
from transformers import ( from transformers import (
AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
) )
@ -45,7 +41,18 @@ from llama_stack.apis.post_training import (
) )
from ..config import HuggingFacePostTrainingConfig from ..config import HuggingFacePostTrainingConfig
from ..utils import get_memory_stats, setup_data, setup_torch_device from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -274,47 +281,10 @@ class HFFinetuningSingleDevice:
raise ValueError(f"Failed to create dataset: {str(e)}") from e raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset # Split dataset
logger.info("Splitting dataset into train and validation sets") train_dataset, eval_dataset = split_dataset(ds)
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset, tokenizer return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args( def setup_training_args(
self, self,
config: TrainingConfig, config: TrainingConfig,
@ -344,27 +314,12 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training") raise ValueError("DataConfig is required for training")
data_config = config.data_config data_config = config.data_config
# Calculate steps # Calculate steps and get save strategy
total_steps = steps_per_epoch * config.n_epochs step_info = calculate_training_steps(steps_per_epoch, config)
max_steps = min(config.max_steps_per_epoch, total_steps) save_strategy, eval_strategy = get_save_strategy(output_dir_path)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
eval_strategy = "no"
if output_dir_path:
save_strategy = "epoch"
eval_strategy = "epoch"
logger.info(f"Will save checkpoints to {output_dir_path}")
return SFTConfig( return SFTConfig(
max_steps=max_steps, max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None, output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs, num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size, per_device_train_batch_size=data_config.batch_size,
@ -388,7 +343,7 @@ class HFFinetuningSingleDevice:
load_best_model_at_end=True if output_dir_path else False, load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss", metric_for_best_model="eval_loss",
greater_is_better=False, greater_is_better=False,
logging_steps=logging_steps, logging_steps=step_info["logging_steps"],
) )
def save_model( def save_model(
@ -428,13 +383,8 @@ class HFFinetuningSingleDevice:
) -> None: ) -> None:
"""Run the training process with signal handling.""" """Run the training process with signal handling."""
def signal_handler(signum, frame): # Setup signal handlers
"""Handle termination signals gracefully.""" setup_signal_handlers()
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Convert config dicts back to objects # Convert config dicts back to objects
logger.info("Initializing configuration objects") logger.info("Initializing configuration objects")
@ -463,7 +413,7 @@ class HFFinetuningSingleDevice:
) )
# Load model # Load model
model_obj = self.load_model(model, device, provider_config_obj) model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer # Initialize trainer
logger.info("Initializing SFTTrainer") logger.info("Initializing SFTTrainer")
@ -538,9 +488,8 @@ class HFFinetuningSingleDevice:
# Train in a separate process # Train in a separate process
logger.info("Starting training in separate process") logger.info("Starting training in separate process")
try: try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility # Setup multiprocessing for device
if device.type in ["cuda", "mps"]: setup_multiprocessing_for_device(device)
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process( process = multiprocessing.Process(
target=self._run_training_sync, target=self._run_training_sync,
@ -568,37 +517,7 @@ class HFFinetuningSingleDevice:
checkpoints = [] checkpoints = []
if output_dir_path: if output_dir_path:
# Get all checkpoint directories and sort them numerically checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
# Add all checkpoint directories
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
# Get the creation time of the directory
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add the merged model as a checkpoint
merged_model_path = output_dir_path / "merged_model"
if merged_model_path.exists():
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(merged_model_path),
)
checkpoints.append(checkpoint)
return memory_stats, checkpoints if checkpoints else None return memory_stats, checkpoints if checkpoints else None
finally: finally:

View file

@ -8,9 +8,6 @@ import gc
import logging import logging
import multiprocessing import multiprocessing
import os import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -27,8 +24,6 @@ os.environ["MKL_NUM_THREADS"] = "1"
import torch import torch
from datasets import Dataset from datasets import Dataset
from transformers import ( from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
) )
from trl import DPOConfig, DPOTrainer from trl import DPOConfig, DPOTrainer
@ -42,7 +37,18 @@ from llama_stack.apis.post_training import (
) )
from ..config import HuggingFacePostTrainingConfig from ..config import HuggingFacePostTrainingConfig
from ..utils import get_memory_stats, setup_data, setup_torch_device from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -251,38 +257,10 @@ class HFDPOAlignmentSingleDevice:
raise ValueError(f"Failed to create dataset: {str(e)}") from e raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset # Split dataset
logger.info("Splitting dataset into train and validation sets") train_dataset, eval_dataset = split_dataset(ds)
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset, tokenizer return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for DPO training."""
logger.info("Loading the base model for DPO")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args( def setup_training_args(
self, self,
config: TrainingConfig, config: TrainingConfig,
@ -304,32 +282,19 @@ class HFDPOAlignmentSingleDevice:
raise ValueError("DataConfig is required for training") raise ValueError("DataConfig is required for training")
data_config = config.data_config data_config = config.data_config
# Calculate steps # Calculate steps and get save strategy
total_steps = steps_per_epoch * config.n_epochs step_info = calculate_training_steps(steps_per_epoch, config)
max_steps = min(config.max_steps_per_epoch, total_steps) save_strategy, eval_strategy = get_save_strategy(output_dir_path)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("DPO training configuration:") logger.info("DPO training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
logger.info(f"- DPO beta: {dpo_config.beta}") logger.info(f"- DPO beta: {dpo_config.beta}")
logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}") logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
# Configure save strategy
save_strategy = "no"
eval_strategy = "no"
if output_dir_path:
save_strategy = "epoch"
eval_strategy = "epoch"
logger.info(f"Will save checkpoints to {output_dir_path}")
# Calculate max prompt length as half of max sequence length # Calculate max prompt length as half of max sequence length
max_prompt_length = provider_config.max_seq_length // 2 max_prompt_length = provider_config.max_seq_length // 2
return DPOConfig( return DPOConfig(
max_steps=max_steps, max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None, output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs, num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size, per_device_train_batch_size=data_config.batch_size,
@ -352,7 +317,7 @@ class HFDPOAlignmentSingleDevice:
load_best_model_at_end=True if output_dir_path else False, load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss", metric_for_best_model="eval_loss",
greater_is_better=False, greater_is_better=False,
logging_steps=logging_steps, logging_steps=step_info["logging_steps"],
save_total_limit=provider_config.save_total_limit, save_total_limit=provider_config.save_total_limit,
# DPO specific parameters # DPO specific parameters
beta=dpo_config.beta, beta=dpo_config.beta,
@ -383,13 +348,8 @@ class HFDPOAlignmentSingleDevice:
) -> None: ) -> None:
"""Run the DPO training process with signal handling.""" """Run the DPO training process with signal handling."""
def signal_handler(signum, frame): # Setup signal handlers
"""Handle termination signals gracefully.""" setup_signal_handlers()
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Convert config dicts back to objects # Convert config dicts back to objects
logger.info("Initializing configuration objects") logger.info("Initializing configuration objects")
@ -420,11 +380,11 @@ class HFDPOAlignmentSingleDevice:
) )
# Load model and reference model # Load model and reference model
model_obj = self.load_model(model, device, provider_config_obj) model_obj = load_model(model, device, provider_config_obj)
ref_model = None ref_model = None
if provider_config_obj.use_reference_model: if provider_config_obj.use_reference_model:
logger.info("Loading separate reference model for DPO") logger.info("Loading separate reference model for DPO")
ref_model = self.load_model(model, device, provider_config_obj) ref_model = load_model(model, device, provider_config_obj)
else: else:
logger.info("Using shared reference model for DPO") logger.info("Using shared reference model for DPO")
@ -496,9 +456,8 @@ class HFDPOAlignmentSingleDevice:
# Train in a separate process # Train in a separate process
logger.info("Starting DPO training in separate process") logger.info("Starting DPO training in separate process")
try: try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility # Setup multiprocessing for device
if device.type in ["cuda", "mps"]: setup_multiprocessing_for_device(device)
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process( process = multiprocessing.Process(
target=self._run_training_sync, target=self._run_training_sync,
@ -526,37 +485,7 @@ class HFDPOAlignmentSingleDevice:
checkpoints = [] checkpoints = []
if output_dir_path: if output_dir_path:
# Get all checkpoint directories and sort them numerically checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
# Add all checkpoint directories
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
# Get the creation time of the directory
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add the DPO model as a checkpoint
dpo_model_path = output_dir_path / "dpo_model"
if dpo_model_path.exists():
checkpoint = Checkpoint(
identifier=f"{model}-dpo-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(dpo_model_path),
)
checkpoints.append(checkpoint)
return memory_stats, checkpoints if checkpoints else None return memory_stats, checkpoints if checkpoints else None
finally: finally:

View file

@ -4,11 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import multiprocessing
import os import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any from typing import Any
import psutil import psutil
import torch import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
def setup_environment(): def setup_environment():
@ -100,7 +115,7 @@ def setup_torch_device(device_str: str) -> torch.device:
return device return device
async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]: async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider""" """Load dataset from llama stack dataset provider"""
try: try:
all_rows = await datasetio_api.iterrows( all_rows = await datasetio_api.iterrows(
@ -112,3 +127,153 @@ async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]:
return all_rows.data return all_rows.data
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def load_model(
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
"""Split dataset into train and validation sets.
Args:
ds: Dataset to split
Returns:
tuple: (train_dataset, eval_dataset)
"""
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset
def setup_signal_handlers():
"""Setup signal handlers for graceful shutdown."""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
"""Calculate training steps and logging configuration.
Args:
steps_per_epoch: Number of training steps per epoch
config: Training configuration
Returns:
dict: Dictionary with calculated step values
"""
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
"""Get save and evaluation strategy based on output directory.
Args:
output_dir_path: Optional path to save the model
Returns:
tuple: (save_strategy, eval_strategy)
"""
if output_dir_path:
logger.info(f"Will save checkpoints to {output_dir_path}")
return "epoch", "epoch"
return "no", "no"
def create_checkpoints(
output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
"""Create checkpoint objects from training output.
Args:
output_dir_path: Path to the training output directory
job_uuid: Unique identifier for the training job
model: Model identifier
config: Training configuration
final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
Returns:
List of Checkpoint objects
"""
checkpoints = []
# Add checkpoint directories
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add final model
final_model_path = output_dir_path / final_model_name
if final_model_path.exists():
training_type = "sft" if final_model_name == "merged_model" else "dpo"
checkpoint = Checkpoint(
identifier=f"{model}-{training_type}-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(final_model_path),
)
checkpoints.append(checkpoint)
return checkpoints
def setup_multiprocessing_for_device(device: torch.device):
"""Setup multiprocessing start method based on device type.
Args:
device: The device being used for training
"""
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)