From 736404c1bd327cbdec531016e9ed13a62ba766b8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 23 Jul 2025 18:05:29 +0000 Subject: [PATCH] removed more redunant code --- .../recipes/finetune_single_device.py | 129 +++----------- .../recipes/finetune_single_device_dpo.py | 121 +++---------- .../inline/post_training/huggingface/utils.py | 167 +++++++++++++++++- 3 files changed, 215 insertions(+), 202 deletions(-) diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index ffb0ce868..935e5ff15 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -9,9 +9,6 @@ import json import logging import multiprocessing import os -import signal -import sys -from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -29,7 +26,6 @@ import torch from datasets import Dataset from peft import LoraConfig from transformers import ( - AutoConfig, AutoModelForCausalLM, AutoTokenizer, ) @@ -45,7 +41,18 @@ from llama_stack.apis.post_training import ( ) from ..config import HuggingFacePostTrainingConfig -from ..utils import get_memory_stats, setup_data, setup_torch_device +from ..utils import ( + calculate_training_steps, + create_checkpoints, + get_memory_stats, + get_save_strategy, + load_model, + setup_data, + setup_multiprocessing_for_device, + setup_signal_handlers, + setup_torch_device, + split_dataset, +) logger = logging.getLogger(__name__) @@ -274,47 +281,10 @@ class HFFinetuningSingleDevice: raise ValueError(f"Failed to create dataset: {str(e)}") from e # Split dataset - logger.info("Splitting dataset into train and validation sets") - train_val_split = ds.train_test_split(test_size=0.1, seed=42) - train_dataset = train_val_split["train"] - eval_dataset = train_val_split["test"] - logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + train_dataset, eval_dataset = split_dataset(ds) return train_dataset, eval_dataset, tokenizer - def load_model( - self, - model: str, - device: torch.device, - provider_config: HuggingFacePostTrainingConfig, - ) -> AutoModelForCausalLM: - """Load and initialize the model for training. 
- Args: - model: The model identifier to load - device: The device to load the model onto - provider_config: Provider-specific configuration - Returns: - The loaded and initialized model - Raises: - RuntimeError: If model loading fails - """ - logger.info("Loading the base model") - try: - model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) - model_obj = AutoModelForCausalLM.from_pretrained( - model, - torch_dtype="auto" if device.type != "cpu" else "float32", - quantization_config=None, - config=model_config, - **provider_config.model_specific_config, - ) - # Always move model to specified device - model_obj = model_obj.to(device) - logger.info(f"Model loaded and moved to device: {model_obj.device}") - return model_obj - except Exception as e: - raise RuntimeError(f"Failed to load model: {str(e)}") from e - def setup_training_args( self, config: TrainingConfig, @@ -344,27 +314,12 @@ class HFFinetuningSingleDevice: raise ValueError("DataConfig is required for training") data_config = config.data_config - # Calculate steps - total_steps = steps_per_epoch * config.n_epochs - max_steps = min(config.max_steps_per_epoch, total_steps) - logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch - - logger.info("Training configuration:") - logger.info(f"- Steps per epoch: {steps_per_epoch}") - logger.info(f"- Total steps: {total_steps}") - logger.info(f"- Max steps: {max_steps}") - logger.info(f"- Logging steps: {logging_steps}") - - # Configure save strategy - save_strategy = "no" - eval_strategy = "no" - if output_dir_path: - save_strategy = "epoch" - eval_strategy = "epoch" - logger.info(f"Will save checkpoints to {output_dir_path}") + # Calculate steps and get save strategy + step_info = calculate_training_steps(steps_per_epoch, config) + save_strategy, eval_strategy = get_save_strategy(output_dir_path) return SFTConfig( - max_steps=max_steps, + max_steps=step_info["max_steps"], output_dir=str(output_dir_path) if output_dir_path is not None else None, num_train_epochs=config.n_epochs, per_device_train_batch_size=data_config.batch_size, @@ -388,7 +343,7 @@ class HFFinetuningSingleDevice: load_best_model_at_end=True if output_dir_path else False, metric_for_best_model="eval_loss", greater_is_better=False, - logging_steps=logging_steps, + logging_steps=step_info["logging_steps"], ) def save_model( @@ -428,13 +383,8 @@ class HFFinetuningSingleDevice: ) -> None: """Run the training process with signal handling.""" - def signal_handler(signum, frame): - """Handle termination signals gracefully.""" - logger.info(f"Received signal {signum}, initiating graceful shutdown") - sys.exit(0) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) + # Setup signal handlers + setup_signal_handlers() # Convert config dicts back to objects logger.info("Initializing configuration objects") @@ -463,7 +413,7 @@ class HFFinetuningSingleDevice: ) # Load model - model_obj = self.load_model(model, device, provider_config_obj) + model_obj = load_model(model, device, provider_config_obj) # Initialize trainer logger.info("Initializing SFTTrainer") @@ -538,9 +488,8 @@ class HFFinetuningSingleDevice: # Train in a separate process logger.info("Starting training in separate process") try: - # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility - if device.type in ["cuda", "mps"]: - multiprocessing.set_start_method("spawn", force=True) + # Setup multiprocessing for device + setup_multiprocessing_for_device(device) 
process = multiprocessing.Process( target=self._run_training_sync, @@ -568,37 +517,7 @@ class HFFinetuningSingleDevice: checkpoints = [] if output_dir_path: - # Get all checkpoint directories and sort them numerically - checkpoint_dirs = sorted( - [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()], - key=lambda x: int(x.name.split("-")[1]), - ) - - # Add all checkpoint directories - for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1): - # Get the creation time of the directory - created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC) - - checkpoint = Checkpoint( - identifier=checkpoint_dir.name, - created_at=created_time, - epoch=epoch_number, - post_training_job_id=job_uuid, - path=str(checkpoint_dir), - ) - checkpoints.append(checkpoint) - - # Add the merged model as a checkpoint - merged_model_path = output_dir_path / "merged_model" - if merged_model_path.exists(): - checkpoint = Checkpoint( - identifier=f"{model}-sft-{config.n_epochs}", - created_at=datetime.now(UTC), - epoch=config.n_epochs, - post_training_job_id=job_uuid, - path=str(merged_model_path), - ) - checkpoints.append(checkpoint) + checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model") return memory_stats, checkpoints if checkpoints else None finally: diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py index 125bf1b60..25d36afd7 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py @@ -8,9 +8,6 @@ import gc import logging import multiprocessing import os -import signal -import sys -from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -27,8 +24,6 @@ os.environ["MKL_NUM_THREADS"] = "1" import torch from datasets import Dataset from transformers import ( - AutoConfig, - AutoModelForCausalLM, AutoTokenizer, ) from trl import DPOConfig, DPOTrainer @@ -42,7 +37,18 @@ from llama_stack.apis.post_training import ( ) from ..config import HuggingFacePostTrainingConfig -from ..utils import get_memory_stats, setup_data, setup_torch_device +from ..utils import ( + calculate_training_steps, + create_checkpoints, + get_memory_stats, + get_save_strategy, + load_model, + setup_data, + setup_multiprocessing_for_device, + setup_signal_handlers, + setup_torch_device, + split_dataset, +) logger = logging.getLogger(__name__) @@ -251,38 +257,10 @@ class HFDPOAlignmentSingleDevice: raise ValueError(f"Failed to create dataset: {str(e)}") from e # Split dataset - logger.info("Splitting dataset into train and validation sets") - train_val_split = ds.train_test_split(test_size=0.1, seed=42) - train_dataset = train_val_split["train"] - eval_dataset = train_val_split["test"] - logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + train_dataset, eval_dataset = split_dataset(ds) return train_dataset, eval_dataset, tokenizer - def load_model( - self, - model: str, - device: torch.device, - provider_config: HuggingFacePostTrainingConfig, - ) -> AutoModelForCausalLM: - """Load and initialize the model for DPO training.""" - logger.info("Loading the base model for DPO") - try: - model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) - model_obj = 
AutoModelForCausalLM.from_pretrained( - model, - torch_dtype="auto" if device.type != "cpu" else "float32", - quantization_config=None, - config=model_config, - **provider_config.model_specific_config, - ) - # Always move model to specified device - model_obj = model_obj.to(device) - logger.info(f"Model loaded and moved to device: {model_obj.device}") - return model_obj - except Exception as e: - raise RuntimeError(f"Failed to load model: {str(e)}") from e - def setup_training_args( self, config: TrainingConfig, @@ -304,32 +282,19 @@ class HFDPOAlignmentSingleDevice: raise ValueError("DataConfig is required for training") data_config = config.data_config - # Calculate steps - total_steps = steps_per_epoch * config.n_epochs - max_steps = min(config.max_steps_per_epoch, total_steps) - logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch + # Calculate steps and get save strategy + step_info = calculate_training_steps(steps_per_epoch, config) + save_strategy, eval_strategy = get_save_strategy(output_dir_path) logger.info("DPO training configuration:") - logger.info(f"- Steps per epoch: {steps_per_epoch}") - logger.info(f"- Total steps: {total_steps}") - logger.info(f"- Max steps: {max_steps}") - logger.info(f"- Logging steps: {logging_steps}") logger.info(f"- DPO beta: {dpo_config.beta}") logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}") - # Configure save strategy - save_strategy = "no" - eval_strategy = "no" - if output_dir_path: - save_strategy = "epoch" - eval_strategy = "epoch" - logger.info(f"Will save checkpoints to {output_dir_path}") - # Calculate max prompt length as half of max sequence length max_prompt_length = provider_config.max_seq_length // 2 return DPOConfig( - max_steps=max_steps, + max_steps=step_info["max_steps"], output_dir=str(output_dir_path) if output_dir_path is not None else None, num_train_epochs=config.n_epochs, per_device_train_batch_size=data_config.batch_size, @@ -352,7 +317,7 @@ class HFDPOAlignmentSingleDevice: load_best_model_at_end=True if output_dir_path else False, metric_for_best_model="eval_loss", greater_is_better=False, - logging_steps=logging_steps, + logging_steps=step_info["logging_steps"], save_total_limit=provider_config.save_total_limit, # DPO specific parameters beta=dpo_config.beta, @@ -383,13 +348,8 @@ class HFDPOAlignmentSingleDevice: ) -> None: """Run the DPO training process with signal handling.""" - def signal_handler(signum, frame): - """Handle termination signals gracefully.""" - logger.info(f"Received signal {signum}, initiating graceful shutdown") - sys.exit(0) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) + # Setup signal handlers + setup_signal_handlers() # Convert config dicts back to objects logger.info("Initializing configuration objects") @@ -420,11 +380,11 @@ class HFDPOAlignmentSingleDevice: ) # Load model and reference model - model_obj = self.load_model(model, device, provider_config_obj) + model_obj = load_model(model, device, provider_config_obj) ref_model = None if provider_config_obj.use_reference_model: logger.info("Loading separate reference model for DPO") - ref_model = self.load_model(model, device, provider_config_obj) + ref_model = load_model(model, device, provider_config_obj) else: logger.info("Using shared reference model for DPO") @@ -496,9 +456,8 @@ class HFDPOAlignmentSingleDevice: # Train in a separate process logger.info("Starting DPO training in separate process") try: - # Set multiprocessing start method to 'spawn' for 
CUDA/MPS compatibility - if device.type in ["cuda", "mps"]: - multiprocessing.set_start_method("spawn", force=True) + # Setup multiprocessing for device + setup_multiprocessing_for_device(device) process = multiprocessing.Process( target=self._run_training_sync, @@ -526,37 +485,7 @@ class HFDPOAlignmentSingleDevice: checkpoints = [] if output_dir_path: - # Get all checkpoint directories and sort them numerically - checkpoint_dirs = sorted( - [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()], - key=lambda x: int(x.name.split("-")[1]), - ) - - # Add all checkpoint directories - for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1): - # Get the creation time of the directory - created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC) - - checkpoint = Checkpoint( - identifier=checkpoint_dir.name, - created_at=created_time, - epoch=epoch_number, - post_training_job_id=job_uuid, - path=str(checkpoint_dir), - ) - checkpoints.append(checkpoint) - - # Add the DPO model as a checkpoint - dpo_model_path = output_dir_path / "dpo_model" - if dpo_model_path.exists(): - checkpoint = Checkpoint( - identifier=f"{model}-dpo-{config.n_epochs}", - created_at=datetime.now(UTC), - epoch=config.n_epochs, - post_training_job_id=job_uuid, - path=str(dpo_model_path), - ) - checkpoints.append(checkpoint) + checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model") return memory_stats, checkpoints if checkpoints else None finally: diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py index c93df77c2..cd75ee06b 100644 --- a/llama_stack/providers/inline/post_training/huggingface/utils.py +++ b/llama_stack/providers/inline/post_training/huggingface/utils.py @@ -4,11 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging +import multiprocessing import os +import signal +import sys +from datetime import UTC, datetime +from pathlib import Path from typing import Any import psutil import torch +from datasets import Dataset +from transformers import AutoConfig, AutoModelForCausalLM + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.post_training import Checkpoint, TrainingConfig + +from .config import HuggingFacePostTrainingConfig + +logger = logging.getLogger(__name__) def setup_environment(): @@ -100,7 +115,7 @@ def setup_torch_device(device_str: str) -> torch.device: return device -async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]: +async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]: """Load dataset from llama stack dataset provider""" try: all_rows = await datasetio_api.iterrows( @@ -112,3 +127,153 @@ async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]: return all_rows.data except Exception as e: raise RuntimeError(f"Failed to load dataset: {str(e)}") from e + + +def load_model( + model: str, + device: torch.device, + provider_config: HuggingFacePostTrainingConfig, +) -> AutoModelForCausalLM: + """Load and initialize the model for training. 
+ Args: + model: The model identifier to load + device: The device to load the model onto + provider_config: Provider-specific configuration + Returns: + The loaded and initialized model + Raises: + RuntimeError: If model loading fails + """ + logger.info("Loading the base model") + try: + model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) + model_obj = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype="auto" if device.type != "cpu" else "float32", + quantization_config=None, + config=model_config, + **provider_config.model_specific_config, + ) + # Always move model to specified device + model_obj = model_obj.to(device) + logger.info(f"Model loaded and moved to device: {model_obj.device}") + return model_obj + except Exception as e: + raise RuntimeError(f"Failed to load model: {str(e)}") from e + + +def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]: + """Split dataset into train and validation sets. + Args: + ds: Dataset to split + Returns: + tuple: (train_dataset, eval_dataset) + """ + logger.info("Splitting dataset into train and validation sets") + train_val_split = ds.train_test_split(test_size=0.1, seed=42) + train_dataset = train_val_split["train"] + eval_dataset = train_val_split["test"] + logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + return train_dataset, eval_dataset + + +def setup_signal_handlers(): + """Setup signal handlers for graceful shutdown.""" + + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, initiating graceful shutdown") + sys.exit(0) + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + +def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]: + """Calculate training steps and logging configuration. + Args: + steps_per_epoch: Number of training steps per epoch + config: Training configuration + Returns: + dict: Dictionary with calculated step values + """ + total_steps = steps_per_epoch * config.n_epochs + max_steps = min(config.max_steps_per_epoch, total_steps) + logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch + + logger.info("Training configuration:") + logger.info(f"- Steps per epoch: {steps_per_epoch}") + logger.info(f"- Total steps: {total_steps}") + logger.info(f"- Max steps: {max_steps}") + logger.info(f"- Logging steps: {logging_steps}") + + return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps} + + +def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]: + """Get save and evaluation strategy based on output directory. + Args: + output_dir_path: Optional path to save the model + Returns: + tuple: (save_strategy, eval_strategy) + """ + if output_dir_path: + logger.info(f"Will save checkpoints to {output_dir_path}") + return "epoch", "epoch" + return "no", "no" + + +def create_checkpoints( + output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str +) -> list[Checkpoint]: + """Create checkpoint objects from training output. 
+ Args: + output_dir_path: Path to the training output directory + job_uuid: Unique identifier for the training job + model: Model identifier + config: Training configuration + final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO) + Returns: + List of Checkpoint objects + """ + checkpoints = [] + + # Add checkpoint directories + checkpoint_dirs = sorted( + [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()], + key=lambda x: int(x.name.split("-")[1]), + ) + + for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1): + created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC) + checkpoint = Checkpoint( + identifier=checkpoint_dir.name, + created_at=created_time, + epoch=epoch_number, + post_training_job_id=job_uuid, + path=str(checkpoint_dir), + ) + checkpoints.append(checkpoint) + + # Add final model + final_model_path = output_dir_path / final_model_name + if final_model_path.exists(): + training_type = "sft" if final_model_name == "merged_model" else "dpo" + checkpoint = Checkpoint( + identifier=f"{model}-{training_type}-{config.n_epochs}", + created_at=datetime.now(UTC), + epoch=config.n_epochs, + post_training_job_id=job_uuid, + path=str(final_model_path), + ) + checkpoints.append(checkpoint) + + return checkpoints + + +def setup_multiprocessing_for_device(device: torch.device): + """Setup multiprocessing start method based on device type. + Args: + device: The device being used for training + """ + if device.type in ["cuda", "mps"]: + multiprocessing.set_start_method("spawn", force=True)
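
Reviewer note (not part of the patch): the sketch below shows how a single-device recipe is expected to compose the helpers consolidated into huggingface/utils.py by this change. The plan_training wrapper, its arguments, and the hard-coded "cuda" device string are illustrative assumptions for review purposes only; the imported helpers and their signatures are the ones introduced or reused in this diff.

# Illustrative usage sketch under the assumptions stated above; only the
# imported helpers come from the patch, the wrapper itself is hypothetical.
from pathlib import Path

from datasets import Dataset

from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.inline.post_training.huggingface.utils import (
    calculate_training_steps,
    get_save_strategy,
    setup_multiprocessing_for_device,
    setup_signal_handlers,
    setup_torch_device,
    split_dataset,
)


def plan_training(ds: Dataset, steps_per_epoch: int, config: TrainingConfig, output_dir: str | None) -> dict:
    # Graceful shutdown on SIGTERM/SIGINT, as both recipes now do via the shared helper.
    setup_signal_handlers()

    # Resolve the device, then force the 'spawn' start method for cuda/mps.
    device = setup_torch_device("cuda")
    setup_multiprocessing_for_device(device)

    # 90/10 train/validation split (seed=42), shared by the SFT and DPO recipes.
    train_dataset, eval_dataset = split_dataset(ds)

    # Step math and save/eval strategy that were previously duplicated inline.
    step_info = calculate_training_steps(steps_per_epoch, config)
    output_dir_path = Path(output_dir) if output_dir else None
    save_strategy, eval_strategy = get_save_strategy(output_dir_path)

    return {
        "max_steps": step_info["max_steps"],
        "logging_steps": step_info["logging_steps"],
        "save_strategy": save_strategy,
        "eval_strategy": eval_strategy,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
    }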