removed more redundant code

Ubuntu 2025-07-23 18:05:29 +00:00
parent 41f4678faf
commit 736404c1bd
3 changed files with 215 additions and 202 deletions

View file

@@ -9,9 +9,6 @@ import json
import logging
import multiprocessing
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
@@ -29,7 +26,6 @@ import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
@@ -45,7 +41,18 @@ from llama_stack.apis.post_training import (
)
from ..config import HuggingFacePostTrainingConfig
from ..utils import get_memory_stats, setup_data, setup_torch_device
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__)
@@ -274,47 +281,10 @@ class HFFinetuningSingleDevice:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
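The split_dataset helper called here lives in ../utils and its body is not part of this diff; a minimal sketch of what it presumably does, reconstructed from the inline train_test_split code removed above (the 0.1 test size and seed 42 come from that removed code; the exact signature is an assumption):

```python
import logging

from datasets import Dataset

logger = logging.getLogger(__name__)


def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
    """Split a dataset into train and validation sets (sketch of the shared helper)."""
    logger.info("Splitting dataset into train and validation sets")
    train_val_split = ds.train_test_split(test_size=0.1, seed=42)
    train_dataset = train_val_split["train"]
    eval_dataset = train_val_split["test"]
    logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
    return train_dataset, eval_dataset
```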
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
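Both recipes now call a shared load_model helper from ../utils; a minimal sketch of that helper, lifted almost verbatim from the method removed above (only the move from an instance method to a module-level function is assumed):

```python
import logging

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from ..config import HuggingFacePostTrainingConfig

logger = logging.getLogger(__name__)


def load_model(
    model: str,
    device: torch.device,
    provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
    """Load the base model and move it to the target device (sketch of the shared helper)."""
    logger.info("Loading the base model")
    try:
        model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
        model_obj = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype="auto" if device.type != "cpu" else "float32",
            quantization_config=None,
            config=model_config,
            **provider_config.model_specific_config,
        )
        # Always move model to specified device
        model_obj = model_obj.to(device)
        logger.info(f"Model loaded and moved to device: {model_obj.device}")
        return model_obj
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
```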
def setup_training_args(
self,
config: TrainingConfig,
@@ -344,27 +314,12 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
eval_strategy = "no"
if output_dir_path:
save_strategy = "epoch"
eval_strategy = "epoch"
logger.info(f"Will save checkpoints to {output_dir_path}")
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
return SFTConfig(
max_steps=max_steps,
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
@@ -388,7 +343,7 @@ class HFFinetuningSingleDevice:
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=logging_steps,
logging_steps=step_info["logging_steps"],
)
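calculate_training_steps and get_save_strategy are likewise shared helpers whose bodies are not shown in this diff; a minimal sketch reconstructed from the step and save-strategy logic removed above (returning a dict matches the step_info["max_steps"] / step_info["logging_steps"] call sites, but the exact keys and return types are assumptions):

```python
import logging
from pathlib import Path

from llama_stack.apis.post_training import TrainingConfig

logger = logging.getLogger(__name__)


def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
    """Derive total, max, and logging steps from the training config (sketch of the shared helper)."""
    total_steps = steps_per_epoch * config.n_epochs
    max_steps = min(config.max_steps_per_epoch, total_steps)
    logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
    logger.info("Training configuration:")
    logger.info(f"- Steps per epoch: {steps_per_epoch}")
    logger.info(f"- Total steps: {total_steps}")
    logger.info(f"- Max steps: {max_steps}")
    logger.info(f"- Logging steps: {logging_steps}")
    return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}


def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
    """Return (save_strategy, eval_strategy): "epoch" when an output dir is set, otherwise "no" (sketch)."""
    if output_dir_path:
        logger.info(f"Will save checkpoints to {output_dir_path}")
        return "epoch", "epoch"
    return "no", "no"
```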
def save_model(
@@ -428,13 +383,8 @@ class HFFinetuningSingleDevice:
) -> None:
"""Run the training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
@@ -463,7 +413,7 @@ class HFFinetuningSingleDevice:
)
# Load model
model_obj = self.load_model(model, device, provider_config_obj)
model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
@@ -538,9 +488,8 @@ class HFFinetuningSingleDevice:
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
# Setup multiprocessing for device
setup_multiprocessing_for_device(device)
process = multiprocessing.Process(
target=self._run_training_sync,
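setup_multiprocessing_for_device wraps the start-method logic removed in this hunk; a minimal sketch based on those removed lines (the helper name comes from the new import block, and its body is assumed to match the old inline code):

```python
import multiprocessing

import torch


def setup_multiprocessing_for_device(device: torch.device) -> None:
    """Use the 'spawn' start method on CUDA/MPS so the training subprocess can use the accelerator."""
    if device.type in ["cuda", "mps"]:
        multiprocessing.set_start_method("spawn", force=True)
```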
@@ -568,37 +517,7 @@ class HFFinetuningSingleDevice:
checkpoints = []
if output_dir_path:
# Get all checkpoint directories and sort them numerically
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
# Add all checkpoint directories
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
# Get the creation time of the directory
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add the merged model as a checkpoint
merged_model_path = output_dir_path / "merged_model"
if merged_model_path.exists():
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(merged_model_path),
)
checkpoints.append(checkpoint)
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
return memory_stats, checkpoints if checkpoints else None
finally:
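Both recipes collapse their checkpoint-collection loops into a single create_checkpoints helper, called with "merged_model" here and with "dpo_model" in the DPO recipe below. A minimal sketch reconstructed from the removed loop; the final_model_name parameter name is an assumption inferred from the two call sites, and how the real helper derives the "-sft-"/"-dpo-" identifier suffix is not visible in this diff, so it is simplified here:

```python
import os
from datetime import UTC, datetime
from pathlib import Path

from llama_stack.apis.post_training import Checkpoint, TrainingConfig


def create_checkpoints(
    output_dir_path: Path,
    job_uuid: str,
    model: str,
    config: TrainingConfig,
    final_model_name: str,
) -> list[Checkpoint]:
    """Collect per-epoch checkpoint dirs plus the final model dir as Checkpoint objects (sketch)."""
    checkpoints = []

    # Get all checkpoint directories and sort them numerically
    checkpoint_dirs = sorted(
        [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
        key=lambda x: int(x.name.split("-")[1]),
    )
    for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
        created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
        checkpoints.append(
            Checkpoint(
                identifier=checkpoint_dir.name,
                created_at=created_time,
                epoch=epoch_number,
                post_training_job_id=job_uuid,
                path=str(checkpoint_dir),
            )
        )

    # Add the final model directory (e.g. merged_model or dpo_model) as a checkpoint, if it exists
    final_model_path = output_dir_path / final_model_name
    if final_model_path.exists():
        checkpoints.append(
            Checkpoint(
                # The real helper uses identifiers like f"{model}-sft-{n_epochs}" / f"{model}-dpo-{n_epochs}";
                # simplified here because the derivation is not visible in this diff
                identifier=f"{model}-{final_model_name}-{config.n_epochs}",
                created_at=datetime.now(UTC),
                epoch=config.n_epochs,
                post_training_job_id=job_uuid,
                path=str(final_model_path),
            )
        )
    return checkpoints
```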

View file

@@ -8,9 +8,6 @@ import gc
import logging
import multiprocessing
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
@@ -27,8 +24,6 @@ os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
from trl import DPOConfig, DPOTrainer
@@ -42,7 +37,18 @@ from llama_stack.apis.post_training import (
)
from ..config import HuggingFacePostTrainingConfig
from ..utils import get_memory_stats, setup_data, setup_torch_device
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__)
@@ -251,38 +257,10 @@ class HFDPOAlignmentSingleDevice:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for DPO training."""
logger.info("Loading the base model for DPO")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args(
self,
config: TrainingConfig,
@@ -304,32 +282,19 @@ class HFDPOAlignmentSingleDevice:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
logger.info("DPO training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
logger.info(f"- DPO beta: {dpo_config.beta}")
logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
# Configure save strategy
save_strategy = "no"
eval_strategy = "no"
if output_dir_path:
save_strategy = "epoch"
eval_strategy = "epoch"
logger.info(f"Will save checkpoints to {output_dir_path}")
# Calculate max prompt length as half of max sequence length
max_prompt_length = provider_config.max_seq_length // 2
return DPOConfig(
max_steps=max_steps,
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
@@ -352,7 +317,7 @@ class HFDPOAlignmentSingleDevice:
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=logging_steps,
logging_steps=step_info["logging_steps"],
save_total_limit=provider_config.save_total_limit,
# DPO specific parameters
beta=dpo_config.beta,
@@ -383,13 +348,8 @@ class HFDPOAlignmentSingleDevice:
) -> None:
"""Run the DPO training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
@@ -420,11 +380,11 @@ class HFDPOAlignmentSingleDevice:
)
# Load model and reference model
model_obj = self.load_model(model, device, provider_config_obj)
model_obj = load_model(model, device, provider_config_obj)
ref_model = None
if provider_config_obj.use_reference_model:
logger.info("Loading separate reference model for DPO")
ref_model = self.load_model(model, device, provider_config_obj)
ref_model = load_model(model, device, provider_config_obj)
else:
logger.info("Using shared reference model for DPO")
@@ -496,9 +456,8 @@ class HFDPOAlignmentSingleDevice:
# Train in a separate process
logger.info("Starting DPO training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
# Setup multiprocessing for device
setup_multiprocessing_for_device(device)
process = multiprocessing.Process(
target=self._run_training_sync,
@@ -526,37 +485,7 @@ class HFDPOAlignmentSingleDevice:
checkpoints = []
if output_dir_path:
# Get all checkpoint directories and sort them numerically
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
# Add all checkpoint directories
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
# Get the creation time of the directory
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add the DPO model as a checkpoint
dpo_model_path = output_dir_path / "dpo_model"
if dpo_model_path.exists():
checkpoint = Checkpoint(
identifier=f"{model}-dpo-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(dpo_model_path),
)
checkpoints.append(checkpoint)
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
return memory_stats, checkpoints if checkpoints else None
finally: