Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 21:00:01 +00:00)
removed more redundant code
This commit is contained in:
parent 41f4678faf
commit 736404c1bd

3 changed files with 215 additions and 202 deletions
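
Judging by the class names and relative imports in the hunks below, the three touched files are the SFT single-device recipe (HFFinetuningSingleDevice), the DPO single-device recipe (HFDPOAlignmentSingleDevice), and the shared utils module both recipes import from. The duplicated helpers move into utils, so each recipe body reduces to a handful of calls, roughly as in this sketch (names are taken from the diff; ds, model, device, provider_config_obj and the rest are assumed to be set up earlier in the recipe):

    # Rough sketch of the refactored recipe flow, not a verbatim excerpt
    train_dataset, eval_dataset = split_dataset(ds)                     # replaces the inlined train_test_split logic
    model_obj = load_model(model, device, provider_config_obj)          # replaces each recipe's private load_model method
    step_info = calculate_training_steps(steps_per_epoch, config)      # replaces the inline step math and logging
    save_strategy, eval_strategy = get_save_strategy(output_dir_path)  # replaces the inline if/else on output_dir_path
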
@@ -9,9 +9,6 @@ import json
 import logging
 import multiprocessing
 import os
-import signal
-import sys
-from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any
 
@@ -29,7 +26,6 @@ import torch
 from datasets import Dataset
 from peft import LoraConfig
 from transformers import (
-    AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
 )
@@ -45,7 +41,18 @@ from llama_stack.apis.post_training import (
 )
 
 from ..config import HuggingFacePostTrainingConfig
-from ..utils import get_memory_stats, setup_data, setup_torch_device
+from ..utils import (
+    calculate_training_steps,
+    create_checkpoints,
+    get_memory_stats,
+    get_save_strategy,
+    load_model,
+    setup_data,
+    setup_multiprocessing_for_device,
+    setup_signal_handlers,
+    setup_torch_device,
+    split_dataset,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -274,47 +281,10 @@ class HFFinetuningSingleDevice:
             raise ValueError(f"Failed to create dataset: {str(e)}") from e
 
         # Split dataset
-        logger.info("Splitting dataset into train and validation sets")
-        train_val_split = ds.train_test_split(test_size=0.1, seed=42)
-        train_dataset = train_val_split["train"]
-        eval_dataset = train_val_split["test"]
-        logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
+        train_dataset, eval_dataset = split_dataset(ds)
 
         return train_dataset, eval_dataset, tokenizer
 
-    def load_model(
-        self,
-        model: str,
-        device: torch.device,
-        provider_config: HuggingFacePostTrainingConfig,
-    ) -> AutoModelForCausalLM:
-        """Load and initialize the model for training.
-
-        Args:
-            model: The model identifier to load
-            device: The device to load the model onto
-            provider_config: Provider-specific configuration
-        Returns:
-            The loaded and initialized model
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        logger.info("Loading the base model")
-        try:
-            model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
-            model_obj = AutoModelForCausalLM.from_pretrained(
-                model,
-                torch_dtype="auto" if device.type != "cpu" else "float32",
-                quantization_config=None,
-                config=model_config,
-                **provider_config.model_specific_config,
-            )
-            # Always move model to specified device
-            model_obj = model_obj.to(device)
-            logger.info(f"Model loaded and moved to device: {model_obj.device}")
-            return model_obj
-        except Exception as e:
-            raise RuntimeError(f"Failed to load model: {str(e)}") from e
-
     def setup_training_args(
         self,
         config: TrainingConfig,
@@ -344,27 +314,12 @@ class HFFinetuningSingleDevice:
             raise ValueError("DataConfig is required for training")
         data_config = config.data_config
 
-        # Calculate steps
-        total_steps = steps_per_epoch * config.n_epochs
-        max_steps = min(config.max_steps_per_epoch, total_steps)
-        logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
-
-        logger.info("Training configuration:")
-        logger.info(f"- Steps per epoch: {steps_per_epoch}")
-        logger.info(f"- Total steps: {total_steps}")
-        logger.info(f"- Max steps: {max_steps}")
-        logger.info(f"- Logging steps: {logging_steps}")
-
-        # Configure save strategy
-        save_strategy = "no"
-        eval_strategy = "no"
-        if output_dir_path:
-            save_strategy = "epoch"
-            eval_strategy = "epoch"
-            logger.info(f"Will save checkpoints to {output_dir_path}")
+        # Calculate steps and get save strategy
+        step_info = calculate_training_steps(steps_per_epoch, config)
+        save_strategy, eval_strategy = get_save_strategy(output_dir_path)
 
         return SFTConfig(
-            max_steps=max_steps,
+            max_steps=step_info["max_steps"],
             output_dir=str(output_dir_path) if output_dir_path is not None else None,
             num_train_epochs=config.n_epochs,
             per_device_train_batch_size=data_config.batch_size,
@@ -388,7 +343,7 @@ class HFFinetuningSingleDevice:
             load_best_model_at_end=True if output_dir_path else False,
             metric_for_best_model="eval_loss",
             greater_is_better=False,
-            logging_steps=logging_steps,
+            logging_steps=step_info["logging_steps"],
         )
 
     def save_model(
@@ -428,13 +383,8 @@ class HFFinetuningSingleDevice:
     ) -> None:
         """Run the training process with signal handling."""
 
-        def signal_handler(signum, frame):
-            """Handle termination signals gracefully."""
-            logger.info(f"Received signal {signum}, initiating graceful shutdown")
-            sys.exit(0)
-
-        signal.signal(signal.SIGTERM, signal_handler)
-        signal.signal(signal.SIGINT, signal_handler)
+        # Setup signal handlers
+        setup_signal_handlers()
 
         # Convert config dicts back to objects
         logger.info("Initializing configuration objects")
@@ -463,7 +413,7 @@ class HFFinetuningSingleDevice:
         )
 
         # Load model
-        model_obj = self.load_model(model, device, provider_config_obj)
+        model_obj = load_model(model, device, provider_config_obj)
 
         # Initialize trainer
         logger.info("Initializing SFTTrainer")
@@ -538,9 +488,8 @@ class HFFinetuningSingleDevice:
         # Train in a separate process
        logger.info("Starting training in separate process")
         try:
-            # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
-            if device.type in ["cuda", "mps"]:
-                multiprocessing.set_start_method("spawn", force=True)
+            # Setup multiprocessing for device
+            setup_multiprocessing_for_device(device)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
@@ -568,37 +517,7 @@ class HFFinetuningSingleDevice:
 
             checkpoints = []
             if output_dir_path:
-                # Get all checkpoint directories and sort them numerically
-                checkpoint_dirs = sorted(
-                    [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
-                    key=lambda x: int(x.name.split("-")[1]),
-                )
-
-                # Add all checkpoint directories
-                for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
-                    # Get the creation time of the directory
-                    created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
-
-                    checkpoint = Checkpoint(
-                        identifier=checkpoint_dir.name,
-                        created_at=created_time,
-                        epoch=epoch_number,
-                        post_training_job_id=job_uuid,
-                        path=str(checkpoint_dir),
-                    )
-                    checkpoints.append(checkpoint)
-
-                # Add the merged model as a checkpoint
-                merged_model_path = output_dir_path / "merged_model"
-                if merged_model_path.exists():
-                    checkpoint = Checkpoint(
-                        identifier=f"{model}-sft-{config.n_epochs}",
-                        created_at=datetime.now(UTC),
-                        epoch=config.n_epochs,
-                        post_training_job_id=job_uuid,
-                        path=str(merged_model_path),
-                    )
-                    checkpoints.append(checkpoint)
+                checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
 
             return memory_stats, checkpoints if checkpoints else None
         finally:
@@ -8,9 +8,6 @@ import gc
 import logging
 import multiprocessing
 import os
-import signal
-import sys
-from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any
 
@@ -27,8 +24,6 @@ os.environ["MKL_NUM_THREADS"] = "1"
 import torch
 from datasets import Dataset
 from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
     AutoTokenizer,
 )
 from trl import DPOConfig, DPOTrainer
@@ -42,7 +37,18 @@ from llama_stack.apis.post_training import (
 )
 
 from ..config import HuggingFacePostTrainingConfig
-from ..utils import get_memory_stats, setup_data, setup_torch_device
+from ..utils import (
+    calculate_training_steps,
+    create_checkpoints,
+    get_memory_stats,
+    get_save_strategy,
+    load_model,
+    setup_data,
+    setup_multiprocessing_for_device,
+    setup_signal_handlers,
+    setup_torch_device,
+    split_dataset,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -251,38 +257,10 @@ class HFDPOAlignmentSingleDevice:
             raise ValueError(f"Failed to create dataset: {str(e)}") from e
 
         # Split dataset
-        logger.info("Splitting dataset into train and validation sets")
-        train_val_split = ds.train_test_split(test_size=0.1, seed=42)
-        train_dataset = train_val_split["train"]
-        eval_dataset = train_val_split["test"]
-        logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
+        train_dataset, eval_dataset = split_dataset(ds)
 
         return train_dataset, eval_dataset, tokenizer
 
-    def load_model(
-        self,
-        model: str,
-        device: torch.device,
-        provider_config: HuggingFacePostTrainingConfig,
-    ) -> AutoModelForCausalLM:
-        """Load and initialize the model for DPO training."""
-        logger.info("Loading the base model for DPO")
-        try:
-            model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
-            model_obj = AutoModelForCausalLM.from_pretrained(
-                model,
-                torch_dtype="auto" if device.type != "cpu" else "float32",
-                quantization_config=None,
-                config=model_config,
-                **provider_config.model_specific_config,
-            )
-            # Always move model to specified device
-            model_obj = model_obj.to(device)
-            logger.info(f"Model loaded and moved to device: {model_obj.device}")
-            return model_obj
-        except Exception as e:
-            raise RuntimeError(f"Failed to load model: {str(e)}") from e
-
     def setup_training_args(
         self,
         config: TrainingConfig,
@@ -304,32 +282,19 @@ class HFDPOAlignmentSingleDevice:
             raise ValueError("DataConfig is required for training")
         data_config = config.data_config
 
-        # Calculate steps
-        total_steps = steps_per_epoch * config.n_epochs
-        max_steps = min(config.max_steps_per_epoch, total_steps)
-        logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
+        # Calculate steps and get save strategy
+        step_info = calculate_training_steps(steps_per_epoch, config)
+        save_strategy, eval_strategy = get_save_strategy(output_dir_path)
 
         logger.info("DPO training configuration:")
-        logger.info(f"- Steps per epoch: {steps_per_epoch}")
-        logger.info(f"- Total steps: {total_steps}")
-        logger.info(f"- Max steps: {max_steps}")
-        logger.info(f"- Logging steps: {logging_steps}")
         logger.info(f"- DPO beta: {dpo_config.beta}")
         logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
 
-        # Configure save strategy
-        save_strategy = "no"
-        eval_strategy = "no"
-        if output_dir_path:
-            save_strategy = "epoch"
-            eval_strategy = "epoch"
-            logger.info(f"Will save checkpoints to {output_dir_path}")
-
         # Calculate max prompt length as half of max sequence length
         max_prompt_length = provider_config.max_seq_length // 2
 
         return DPOConfig(
-            max_steps=max_steps,
+            max_steps=step_info["max_steps"],
             output_dir=str(output_dir_path) if output_dir_path is not None else None,
             num_train_epochs=config.n_epochs,
             per_device_train_batch_size=data_config.batch_size,
@@ -352,7 +317,7 @@ class HFDPOAlignmentSingleDevice:
             load_best_model_at_end=True if output_dir_path else False,
             metric_for_best_model="eval_loss",
             greater_is_better=False,
-            logging_steps=logging_steps,
+            logging_steps=step_info["logging_steps"],
             save_total_limit=provider_config.save_total_limit,
             # DPO specific parameters
             beta=dpo_config.beta,
@@ -383,13 +348,8 @@ class HFDPOAlignmentSingleDevice:
     ) -> None:
         """Run the DPO training process with signal handling."""
 
-        def signal_handler(signum, frame):
-            """Handle termination signals gracefully."""
-            logger.info(f"Received signal {signum}, initiating graceful shutdown")
-            sys.exit(0)
-
-        signal.signal(signal.SIGTERM, signal_handler)
-        signal.signal(signal.SIGINT, signal_handler)
+        # Setup signal handlers
+        setup_signal_handlers()
 
         # Convert config dicts back to objects
         logger.info("Initializing configuration objects")
@@ -420,11 +380,11 @@ class HFDPOAlignmentSingleDevice:
         )
 
         # Load model and reference model
-        model_obj = self.load_model(model, device, provider_config_obj)
+        model_obj = load_model(model, device, provider_config_obj)
         ref_model = None
         if provider_config_obj.use_reference_model:
             logger.info("Loading separate reference model for DPO")
-            ref_model = self.load_model(model, device, provider_config_obj)
+            ref_model = load_model(model, device, provider_config_obj)
         else:
             logger.info("Using shared reference model for DPO")
 
@@ -496,9 +456,8 @@ class HFDPOAlignmentSingleDevice:
         # Train in a separate process
         logger.info("Starting DPO training in separate process")
         try:
-            # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
-            if device.type in ["cuda", "mps"]:
-                multiprocessing.set_start_method("spawn", force=True)
+            # Setup multiprocessing for device
+            setup_multiprocessing_for_device(device)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
@@ -526,37 +485,7 @@ class HFDPOAlignmentSingleDevice:
 
             checkpoints = []
             if output_dir_path:
-                # Get all checkpoint directories and sort them numerically
-                checkpoint_dirs = sorted(
-                    [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
-                    key=lambda x: int(x.name.split("-")[1]),
-                )
-
-                # Add all checkpoint directories
-                for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
-                    # Get the creation time of the directory
-                    created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
-
-                    checkpoint = Checkpoint(
-                        identifier=checkpoint_dir.name,
-                        created_at=created_time,
-                        epoch=epoch_number,
-                        post_training_job_id=job_uuid,
-                        path=str(checkpoint_dir),
-                    )
-                    checkpoints.append(checkpoint)
-
-                # Add the DPO model as a checkpoint
-                dpo_model_path = output_dir_path / "dpo_model"
-                if dpo_model_path.exists():
-                    checkpoint = Checkpoint(
-                        identifier=f"{model}-dpo-{config.n_epochs}",
-                        created_at=datetime.now(UTC),
-                        epoch=config.n_epochs,
-                        post_training_job_id=job_uuid,
-                        path=str(dpo_model_path),
-                    )
-                    checkpoints.append(checkpoint)
+                checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
 
             return memory_stats, checkpoints if checkpoints else None
         finally:
@@ -4,11 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import logging
+import multiprocessing
 import os
+import signal
+import sys
+from datetime import UTC, datetime
+from pathlib import Path
 from typing import Any
 
 import psutil
 import torch
+from datasets import Dataset
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+
+from .config import HuggingFacePostTrainingConfig
+
+logger = logging.getLogger(__name__)
 
 
 def setup_environment():
@@ -100,7 +115,7 @@ def setup_torch_device(device_str: str) -> torch.device:
     return device
 
 
-async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]:
+async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
         all_rows = await datasetio_api.iterrows(
@@ -112,3 +127,153 @@ async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
         return all_rows.data
     except Exception as e:
         raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+
+
+def load_model(
+    model: str,
+    device: torch.device,
+    provider_config: HuggingFacePostTrainingConfig,
+) -> AutoModelForCausalLM:
+    """Load and initialize the model for training.
+
+    Args:
+        model: The model identifier to load
+        device: The device to load the model onto
+        provider_config: Provider-specific configuration
+    Returns:
+        The loaded and initialized model
+    Raises:
+        RuntimeError: If model loading fails
+    """
+    logger.info("Loading the base model")
+    try:
+        model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
+        model_obj = AutoModelForCausalLM.from_pretrained(
+            model,
+            torch_dtype="auto" if device.type != "cpu" else "float32",
+            quantization_config=None,
+            config=model_config,
+            **provider_config.model_specific_config,
+        )
+        # Always move model to specified device
+        model_obj = model_obj.to(device)
+        logger.info(f"Model loaded and moved to device: {model_obj.device}")
+        return model_obj
+    except Exception as e:
+        raise RuntimeError(f"Failed to load model: {str(e)}") from e
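
The helper above is what both recipes now call instead of their private load_model methods. A minimal usage sketch, with a placeholder model id and assuming the provider config's defaults are acceptable:

    # Illustrative only; "org/model-name" is a placeholder, not a value from the diff
    provider_config = HuggingFacePostTrainingConfig()  # assumes the config's defaults suffice
    device = setup_torch_device("cuda")                # setup_torch_device already lives in this module
    model_obj = load_model("org/model-name", device, provider_config)
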
+
+
+def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
+    """Split dataset into train and validation sets.
+
+    Args:
+        ds: Dataset to split
+    Returns:
+        tuple: (train_dataset, eval_dataset)
+    """
+    logger.info("Splitting dataset into train and validation sets")
+    train_val_split = ds.train_test_split(test_size=0.1, seed=42)
+    train_dataset = train_val_split["train"]
+    eval_dataset = train_val_split["test"]
+    logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
+    return train_dataset, eval_dataset
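
For reference, a small sketch of what split_dataset does with a toy dataset; the 90/10 split and the fixed seed come from the function body above:

    from datasets import Dataset

    ds = Dataset.from_list([{"text": f"example {i}"} for i in range(100)])  # toy data, not from the diff
    train_dataset, eval_dataset = split_dataset(ds)
    # test_size=0.1 gives roughly 90 training and 10 validation examples here
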
+
+
+def setup_signal_handlers():
+    """Setup signal handlers for graceful shutdown."""
+
+    def signal_handler(signum, frame):
+        logger.info(f"Received signal {signum}, initiating graceful shutdown")
+        sys.exit(0)
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+
+def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
+    """Calculate training steps and logging configuration.
+
+    Args:
+        steps_per_epoch: Number of training steps per epoch
+        config: Training configuration
+    Returns:
+        dict: Dictionary with calculated step values
+    """
+    total_steps = steps_per_epoch * config.n_epochs
+    max_steps = min(config.max_steps_per_epoch, total_steps)
+    logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
+
+    logger.info("Training configuration:")
+    logger.info(f"- Steps per epoch: {steps_per_epoch}")
+    logger.info(f"- Total steps: {total_steps}")
+    logger.info(f"- Max steps: {max_steps}")
+    logger.info(f"- Logging steps: {logging_steps}")
+
+    return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
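
A worked example of calculate_training_steps with made-up numbers (assuming config.n_epochs == 3 and config.max_steps_per_epoch == 150):

    step_info = calculate_training_steps(steps_per_epoch=100, config=config)
    # total_steps   = 100 * 3           = 300
    # max_steps     = min(150, 300)     = 150
    # logging_steps = max(1, 100 // 50) = 2
    # -> {"total_steps": 300, "max_steps": 150, "logging_steps": 2}
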
+
+
+def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
+    """Get save and evaluation strategy based on output directory.
+
+    Args:
+        output_dir_path: Optional path to save the model
+    Returns:
+        tuple: (save_strategy, eval_strategy)
+    """
+    if output_dir_path:
+        logger.info(f"Will save checkpoints to {output_dir_path}")
+        return "epoch", "epoch"
+    return "no", "no"
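
Both helpers feed straight into the trainer arguments in the recipe hunks earlier in this diff; a hedged sketch of that call site (variable names mirror the recipes, and the save/eval strategy kwargs are not visible in the hunks shown here):

    step_info = calculate_training_steps(steps_per_epoch, config)
    save_strategy, eval_strategy = get_save_strategy(output_dir_path)  # ("epoch", "epoch") or ("no", "no")
    args = SFTConfig(  # the DPO recipe builds a DPOConfig the same way
        max_steps=step_info["max_steps"],
        logging_steps=step_info["logging_steps"],
        output_dir=str(output_dir_path) if output_dir_path is not None else None,
        # save_strategy / eval_strategy presumably get passed through as well
    )
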
+
+
+def create_checkpoints(
+    output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
+) -> list[Checkpoint]:
+    """Create checkpoint objects from training output.
+
+    Args:
+        output_dir_path: Path to the training output directory
+        job_uuid: Unique identifier for the training job
+        model: Model identifier
+        config: Training configuration
+        final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
+    Returns:
+        List of Checkpoint objects
+    """
+    checkpoints = []
+
+    # Add checkpoint directories
+    checkpoint_dirs = sorted(
+        [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
+        key=lambda x: int(x.name.split("-")[1]),
+    )
+
+    for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
+        created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
+        checkpoint = Checkpoint(
+            identifier=checkpoint_dir.name,
+            created_at=created_time,
+            epoch=epoch_number,
+            post_training_job_id=job_uuid,
+            path=str(checkpoint_dir),
+        )
+        checkpoints.append(checkpoint)
+
+    # Add final model
+    final_model_path = output_dir_path / final_model_name
+    if final_model_path.exists():
+        training_type = "sft" if final_model_name == "merged_model" else "dpo"
+        checkpoint = Checkpoint(
+            identifier=f"{model}-{training_type}-{config.n_epochs}",
+            created_at=datetime.now(UTC),
+            epoch=config.n_epochs,
+            post_training_job_id=job_uuid,
+            path=str(final_model_path),
+        )
+        checkpoints.append(checkpoint)
+
+    return checkpoints
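
As an illustration of the return value: for a hypothetical two-epoch SFT run whose output directory contains checkpoint-50, checkpoint-100 and a merged_model folder, the helper would yield something like:

    checkpoints = create_checkpoints(output_dir_path, job_uuid, "org/model-name", config, "merged_model")
    # [
    #   Checkpoint(identifier="checkpoint-50",        epoch=1, path=".../checkpoint-50", ...),
    #   Checkpoint(identifier="checkpoint-100",       epoch=2, path=".../checkpoint-100", ...),
    #   Checkpoint(identifier="org/model-name-sft-2", epoch=2, path=".../merged_model", ...),
    # ]
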
+
+
+def setup_multiprocessing_for_device(device: torch.device):
+    """Setup multiprocessing start method based on device type.
+
+    Args:
+        device: The device being used for training
+    """
+    if device.type in ["cuda", "mps"]:
+        multiprocessing.set_start_method("spawn", force=True)
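
Both recipes call this right before spawning the training process: CUDA and MPS state generally cannot survive a fork, so the child has to be started with the "spawn" method. A minimal sketch of the call site inside a recipe (run_kwargs is an assumed name for the arguments assembled earlier):

    setup_multiprocessing_for_device(device)  # no-op for CPU, forces "spawn" for cuda/mps
    process = multiprocessing.Process(target=self._run_training_sync, kwargs=run_kwargs)
    process.start()
    process.join()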