mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
removed more redunant code
This commit is contained in:
parent
41f4678faf
commit
736404c1bd
3 changed files with 215 additions and 202 deletions
|
@ -4,11 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_environment():
|
||||
|
@ -100,7 +115,7 @@ def setup_torch_device(device_str: str) -> torch.device:
|
|||
return device
|
||||
|
||||
|
||||
async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]:
|
||||
async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Load dataset from llama stack dataset provider"""
|
||||
try:
|
||||
all_rows = await datasetio_api.iterrows(
|
||||
|
@ -112,3 +127,153 @@ async def setup_data(datasetio_api, dataset_id: str) -> list[dict[str, Any]]:
|
|||
return all_rows.data
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
|
||||
|
||||
|
||||
def load_model(
|
||||
model: str,
|
||||
device: torch.device,
|
||||
provider_config: HuggingFacePostTrainingConfig,
|
||||
) -> AutoModelForCausalLM:
|
||||
"""Load and initialize the model for training.
|
||||
Args:
|
||||
model: The model identifier to load
|
||||
device: The device to load the model onto
|
||||
provider_config: Provider-specific configuration
|
||||
Returns:
|
||||
The loaded and initialized model
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
logger.info("Loading the base model")
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
|
||||
model_obj = AutoModelForCausalLM.from_pretrained(
|
||||
model,
|
||||
torch_dtype="auto" if device.type != "cpu" else "float32",
|
||||
quantization_config=None,
|
||||
config=model_config,
|
||||
**provider_config.model_specific_config,
|
||||
)
|
||||
# Always move model to specified device
|
||||
model_obj = model_obj.to(device)
|
||||
logger.info(f"Model loaded and moved to device: {model_obj.device}")
|
||||
return model_obj
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
||||
|
||||
|
||||
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
|
||||
"""Split dataset into train and validation sets.
|
||||
Args:
|
||||
ds: Dataset to split
|
||||
Returns:
|
||||
tuple: (train_dataset, eval_dataset)
|
||||
"""
|
||||
logger.info("Splitting dataset into train and validation sets")
|
||||
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
|
||||
train_dataset = train_val_split["train"]
|
||||
eval_dataset = train_val_split["test"]
|
||||
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_signal_handlers():
|
||||
"""Setup signal handlers for graceful shutdown."""
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
|
||||
"""Calculate training steps and logging configuration.
|
||||
Args:
|
||||
steps_per_epoch: Number of training steps per epoch
|
||||
config: Training configuration
|
||||
Returns:
|
||||
dict: Dictionary with calculated step values
|
||||
"""
|
||||
total_steps = steps_per_epoch * config.n_epochs
|
||||
max_steps = min(config.max_steps_per_epoch, total_steps)
|
||||
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
|
||||
|
||||
logger.info("Training configuration:")
|
||||
logger.info(f"- Steps per epoch: {steps_per_epoch}")
|
||||
logger.info(f"- Total steps: {total_steps}")
|
||||
logger.info(f"- Max steps: {max_steps}")
|
||||
logger.info(f"- Logging steps: {logging_steps}")
|
||||
|
||||
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
|
||||
|
||||
|
||||
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
|
||||
"""Get save and evaluation strategy based on output directory.
|
||||
Args:
|
||||
output_dir_path: Optional path to save the model
|
||||
Returns:
|
||||
tuple: (save_strategy, eval_strategy)
|
||||
"""
|
||||
if output_dir_path:
|
||||
logger.info(f"Will save checkpoints to {output_dir_path}")
|
||||
return "epoch", "epoch"
|
||||
return "no", "no"
|
||||
|
||||
|
||||
def create_checkpoints(
|
||||
output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
|
||||
) -> list[Checkpoint]:
|
||||
"""Create checkpoint objects from training output.
|
||||
Args:
|
||||
output_dir_path: Path to the training output directory
|
||||
job_uuid: Unique identifier for the training job
|
||||
model: Model identifier
|
||||
config: Training configuration
|
||||
final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
|
||||
Returns:
|
||||
List of Checkpoint objects
|
||||
"""
|
||||
checkpoints = []
|
||||
|
||||
# Add checkpoint directories
|
||||
checkpoint_dirs = sorted(
|
||||
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
|
||||
key=lambda x: int(x.name.split("-")[1]),
|
||||
)
|
||||
|
||||
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
|
||||
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=checkpoint_dir.name,
|
||||
created_at=created_time,
|
||||
epoch=epoch_number,
|
||||
post_training_job_id=job_uuid,
|
||||
path=str(checkpoint_dir),
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
# Add final model
|
||||
final_model_path = output_dir_path / final_model_name
|
||||
if final_model_path.exists():
|
||||
training_type = "sft" if final_model_name == "merged_model" else "dpo"
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{model}-{training_type}-{config.n_epochs}",
|
||||
created_at=datetime.now(UTC),
|
||||
epoch=config.n_epochs,
|
||||
post_training_job_id=job_uuid,
|
||||
path=str(final_model_path),
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
def setup_multiprocessing_for_device(device: torch.device):
|
||||
"""Setup multiprocessing start method based on device type.
|
||||
Args:
|
||||
device: The device being used for training
|
||||
"""
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue