removed more redundant code

Ubuntu 2025-07-23 18:05:29 +00:00
parent 41f4678faf
commit 736404c1bd
3 changed files with 215 additions and 202 deletions

@@ -4,11 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import multiprocessing
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

import psutil
import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig

from .config import HuggingFacePostTrainingConfig

logger = logging.getLogger(__name__)


def setup_environment():
@@ -100,7 +115,7 @@ def setup_torch_device(device_str: str) -> torch.device:
    return device


async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]:
async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
    """Load dataset from llama stack dataset provider"""
    try:
        all_rows = await datasetio_api.iterrows(
@@ -112,3 +127,153 @@ async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
        return all_rows.data
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset: {str(e)}") from e


def load_model(
    model: str,
    device: torch.device,
    provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
    """Load and initialize the model for training.

    Args:
        model: The model identifier to load
        device: The device to load the model onto
        provider_config: Provider-specific configuration

    Returns:
        The loaded and initialized model

    Raises:
        RuntimeError: If model loading fails
    """
    logger.info("Loading the base model")
    try:
        model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
        model_obj = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype="auto" if device.type != "cpu" else "float32",
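            # Editorial note: "auto" lets transformers pick the dtype recorded in the
            # checkpoint config on accelerators, while CPU runs fall back to float32.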
            quantization_config=None,
            config=model_config,
            **provider_config.model_specific_config,
        )

        # Always move model to specified device
        model_obj = model_obj.to(device)
        logger.info(f"Model loaded and moved to device: {model_obj.device}")
        return model_obj
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}") from e


def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
    """Split dataset into train and validation sets.

    Args:
        ds: Dataset to split

    Returns:
        tuple: (train_dataset, eval_dataset)
    """
    logger.info("Splitting dataset into train and validation sets")
    train_val_split = ds.train_test_split(test_size=0.1, seed=42)
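    # Editorial note: this holds out 10% of the rows for evaluation; the fixed
    # seed keeps the split deterministic across runs.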
    train_dataset = train_val_split["train"]
    eval_dataset = train_val_split["test"]
    logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")

    return train_dataset, eval_dataset


def setup_signal_handlers():
    """Setup signal handlers for graceful shutdown."""

    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, initiating graceful shutdown")
        sys.exit(0)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)


def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
    """Calculate training steps and logging configuration.

    Args:
        steps_per_epoch: Number of training steps per epoch
        config: Training configuration

    Returns:
        dict: Dictionary with calculated step values
    """
    total_steps = steps_per_epoch * config.n_epochs
    max_steps = min(config.max_steps_per_epoch, total_steps)
    logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
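    # Worked example (hypothetical numbers, not taken from this commit): with
    # steps_per_epoch=500, n_epochs=3 and max_steps_per_epoch=1000, this gives
    # total_steps = 500 * 3 = 1500, max_steps = min(1000, 1500) = 1000 and
    # logging_steps = max(1, 500 // 50) = 10.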
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}


def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
    """Get save and evaluation strategy based on output directory.

    Args:
        output_dir_path: Optional path to save the model

    Returns:
        tuple: (save_strategy, eval_strategy)
    """
    if output_dir_path:
        logger.info(f"Will save checkpoints to {output_dir_path}")
        return "epoch", "epoch"

    return "no", "no"


def create_checkpoints(
    output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
    """Create checkpoint objects from training output.

    Args:
        output_dir_path: Path to the training output directory
        job_uuid: Unique identifier for the training job
        model: Model identifier
        config: Training configuration
        final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)

    Returns:
        List of Checkpoint objects
    """
    checkpoints = []

    # Add checkpoint directories
    checkpoint_dirs = sorted(
        [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
        key=lambda x: int(x.name.split("-")[1]),
    )

    for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
        created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
        checkpoint = Checkpoint(
            identifier=checkpoint_dir.name,
            created_at=created_time,
            epoch=epoch_number,
            post_training_job_id=job_uuid,
            path=str(checkpoint_dir),
        )
        checkpoints.append(checkpoint)

    # Add final model
    final_model_path = output_dir_path / final_model_name
    if final_model_path.exists():
        training_type = "sft" if final_model_name == "merged_model" else "dpo"
        checkpoint = Checkpoint(
            identifier=f"{model}-{training_type}-{config.n_epochs}",
            created_at=datetime.now(UTC),
            epoch=config.n_epochs,
            post_training_job_id=job_uuid,
            path=str(final_model_path),
        )
        checkpoints.append(checkpoint)

    return checkpoints


def setup_multiprocessing_for_device(device: torch.device):
    """Setup multiprocessing start method based on device type.

    Args:
        device: The device being used for training
    """
    if device.type in ["cuda", "mps"]:
        multiprocessing.set_start_method("spawn", force=True)
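        # Editorial note: CUDA and MPS state generally does not survive fork(), so
        # "spawn" is the safer start method once a device has been initialized.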