mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
	# What does this PR do? This PR adds a pre-commit step that enforces use of the `llama_stack` logger. Currently, various parts of the codebase use different loggers. Since a custom `llama_stack` logger exists and is already used in the codebase, it is better to standardize on it. Signed-off-by: Mustafa Elbehery <melbeher@redhat.com> Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>
		
			
				
	
	
		
			269 lines
		
	
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			269 lines
		
	
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import os
 | |
| import signal
 | |
| import sys
 | |
| from datetime import UTC, datetime
 | |
| from pathlib import Path
 | |
| from typing import Any
 | |
| 
 | |
| import psutil
 | |
| import torch
 | |
| from datasets import Dataset
 | |
| from transformers import AutoConfig, AutoModelForCausalLM
 | |
| 
 | |
| from llama_stack.apis.datasetio import DatasetIO
 | |
| from llama_stack.apis.post_training import Checkpoint, TrainingConfig
 | |
| from llama_stack.log import get_logger
 | |
| 
 | |
| from .config import HuggingFacePostTrainingConfig
 | |
| 
 | |
# Module-level logger from llama_stack's own get_logger helper (rather than the stdlib
# logging module), tagged with the "post_training" category.
logger = get_logger(name=__name__, category="post_training")
 | |
| 
 | |
| 
 | |
def setup_environment():
    """Set the process environment variables shared by all training runs.

    Writes into ``os.environ`` (affecting this process and any children it
    spawns afterwards):

    - ``TOKENIZERS_PARALLELISM`` = ``"false"``
    - ``MKL_THREADING_LAYER`` = ``"GNU"``
    - ``MKL_SERVICE_FORCE_INTEL`` = ``"0"``
    - ``MKL_NUM_THREADS`` = ``"1"``
    """
    training_env = {
        "TOKENIZERS_PARALLELISM": "false",
        "MKL_THREADING_LAYER": "GNU",
        "MKL_SERVICE_FORCE_INTEL": "0",
        "MKL_NUM_THREADS": "1",
    }
    os.environ.update(training_env)
 | |
| 
 | |
| 
 | |
def bytes_to_gb(to_convert: int) -> str:
    """Render a byte count as gigabytes (1 GB = 1024**3 bytes) with two decimals.

    Args:
        to_convert: Memory value in bytes

    Returns:
        str: The value in GB, e.g. ``"1.50"`` for ``3 * 1024**3 // 2`` bytes.
    """
    gigabytes = to_convert / (1024**3)
    return format(gigabytes, ".2f")
 | |
| 
 | |
| 
 | |
def get_memory_stats(device: torch.device) -> dict[str, Any]:
    """Get memory statistics for the given device.

    Args:
        device: Device training runs on. ``cuda``, ``mps`` and ``cpu`` devices get
            a ``"device_memory"`` entry; any other device type only gets
            ``"system_memory"``.

    Returns:
        dict with:
          - ``"system_memory"``: ``total``/``available``/``used`` as GB strings
            (via ``bytes_to_gb``) and ``percent`` as reported by psutil.
          - ``"device_memory"`` (cuda/mps/cpu only): CUDA allocator stats in GB,
            a note plus system-wide used memory for MPS, or this process's
            RSS/VMS in GB and memory percent for CPU.
    """
    # Take a single snapshot so total/available/used/percent describe the same
    # instant; querying psutil once per field could mix different samples.
    vmem = psutil.virtual_memory()
    stats: dict[str, Any] = {
        "system_memory": {
            "total": bytes_to_gb(vmem.total),
            "available": bytes_to_gb(vmem.available),
            "used": bytes_to_gb(vmem.used),
            "percent": vmem.percent,
        }
    }

    if device.type == "cuda":
        stats["device_memory"] = {
            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
        }
    elif device.type == "mps":
        # No per-device allocator stats are queried for MPS; report system-wide
        # usage from the same snapshot as "system_memory" instead.
        stats["device_memory"] = {
            "note": "MPS memory stats not directly available",
            "system_memory_used": bytes_to_gb(vmem.used),
        }
    elif device.type == "cpu":
        # For CPU, track this process's own memory; read memory_info() once so
        # RSS and VMS come from the same sample.
        process = psutil.Process()
        mem_info = process.memory_info()
        stats["device_memory"] = {
            "process_rss": bytes_to_gb(mem_info.rss),
            "process_vms": bytes_to_gb(mem_info.vms),
            "process_percent": process.memory_percent(),
        }

    return stats
 | |
| 
 | |
| 
 | |
def setup_torch_device(device_str: str) -> torch.device:
    """Build a ``torch.device`` from a string and check it can be used for training.

    Checks applied per device type:
      - ``cuda``: CUDA/ROCm must be available; a bare ``"cuda"`` (no index) is
        pinned to ``torch.cuda.current_device()``.
      - ``mps``: ``torch.backends.mps`` must report availability.
      - ``hpu``: always rejected.
      - anything else (e.g. ``cpu``): returned as parsed.

    Args:
        device_str: Device specification such as ``"cuda"``, ``"cuda:1"``,
            ``"cpu"`` or ``"mps"``.

    Returns:
        torch.device: The validated (and, for CUDA, index-resolved) device.

    Raises:
        RuntimeError: If the string cannot be parsed, the backend is unavailable,
            or the device type is unsupported.
    """
    try:
        resolved = torch.device(device_str)
    except RuntimeError as err:
        raise RuntimeError(f"Error getting Torch Device {err}") from err

    kind = resolved.type

    if kind == "hpu":
        raise RuntimeError(f"{kind}: training does not support Intel Gaudi.")

    if kind == "mps" and not torch.backends.mps.is_available():
        raise RuntimeError(f"{kind}: Torch has no MPS support or could not detect a compatible device.")

    if kind == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"{kind}: Torch has no CUDA/ROCm support or could not detect a compatible device."
            )
        # Make the device index explicit so later calls don't depend on an implicit default.
        if resolved.index is None:
            resolved = torch.device(kind, torch.cuda.current_device())

    return resolved
 | |
| 
 | |
| 
 | |
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
    """Fetch the rows of a dataset through the llama stack DatasetIO provider.

    Args:
        datasetio_api: Provider whose ``iterrows`` is awaited.
        dataset_id: Identifier of the dataset to read.

    Returns:
        The ``data`` list from the ``iterrows`` response. ``limit=-1`` is passed,
        presumably meaning "no limit" — confirm against the provider.

    Raises:
        RuntimeError: If the provider call fails or ``data`` is not a list. Any
            underlying error (including the not-a-list case) is chained and its
            text is appended to "Failed to load dataset: ".
    """
    try:
        page = await datasetio_api.iterrows(dataset_id=dataset_id, limit=-1)
        rows = page.data
        if isinstance(rows, list):
            return rows
        raise RuntimeError("Expected dataset data to be a list")
    except Exception as err:
        raise RuntimeError(f"Failed to load dataset: {err}") from err
 | |
| 
 | |
| 
 | |
def load_model(
    model: str,
    device: torch.device,
    provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
    """Load a causal-LM checkpoint with transformers and move it onto ``device``.

    ``provider_config.model_specific_config`` is forwarded as keyword arguments
    to both ``AutoConfig.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``.
    On CPU the weights are requested as ``"float32"``; on other devices
    ``torch_dtype="auto"`` is passed. No quantization config is used.

    Args:
        model: The model identifier to load
        device: The device to load the model onto
        provider_config: Provider-specific configuration

    Returns:
        The loaded model, after ``.to(device)``.

    Raises:
        RuntimeError: Wrapping (and chaining) any exception raised while loading
            or moving the model.
    """
    logger.info("Loading the base model")
    try:
        hf_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
        dtype = "float32" if device.type == "cpu" else "auto"
        loaded = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=dtype,
            quantization_config=None,
            config=hf_config,
            **provider_config.model_specific_config,
        )
        # Move explicitly so the model always ends up on the requested device.
        loaded = loaded.to(device)
        logger.info(f"Model loaded and moved to device: {loaded.device}")
        return loaded
    except Exception as err:
        raise RuntimeError(f"Failed to load model: {err}") from err
 | |
| 
 | |
| 
 | |
def split_dataset(ds: Dataset, test_size: float = 0.1, seed: int = 42) -> tuple[Dataset, Dataset]:
    """Split dataset into train and validation sets.

    Args:
        ds: Dataset to split
        test_size: Size of the validation split, forwarded to
            ``Dataset.train_test_split`` (defaults to 0.1, i.e. 10% of rows).
        seed: Seed forwarded to ``Dataset.train_test_split``; the fixed default
            keeps the split identical across runs.

    Returns:
        tuple: (train_dataset, eval_dataset) — the ``"train"`` and ``"test"``
        splits returned by ``train_test_split``.
    """
    logger.info("Splitting dataset into train and validation sets")
    train_val_split = ds.train_test_split(test_size=test_size, seed=seed)
    train_dataset = train_val_split["train"]
    eval_dataset = train_val_split["test"]
    logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
    return train_dataset, eval_dataset
 | |
| 
 | |
| 
 | |
def setup_signal_handlers():
    """Install one shared handler for SIGTERM and SIGINT.

    The handler logs which signal arrived and then calls ``sys.exit(0)``, which
    raises ``SystemExit`` so the interpreter unwinds normally. Python only allows
    ``signal.signal`` to be called from the main thread.
    """

    def _shutdown(signum, frame):
        logger.info(f"Received signal {signum}, initiating graceful shutdown")
        sys.exit(0)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _shutdown)
 | |
| 
 | |
| 
 | |
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
    """Derive step counts and the logging interval for a training run, and log them.

    Args:
        steps_per_epoch: Number of training steps per epoch
        config: Training configuration (reads ``n_epochs`` and ``max_steps_per_epoch``)

    Returns:
        dict with keys ``"total_steps"`` (steps_per_epoch * n_epochs),
        ``"max_steps"`` (the smaller of ``max_steps_per_epoch`` and total_steps)
        and ``"logging_steps"`` (steps_per_epoch // 50, at least 1).
    """
    total = steps_per_epoch * config.n_epochs
    # NOTE(review): a *per-epoch* limit is compared against the *total* step
    # count here, so runs with n_epochs > 1 are capped at max_steps_per_epoch
    # steps overall — confirm this is what the trainer's max_steps expects.
    capped = min(config.max_steps_per_epoch, total)
    # Roughly 50 log events per epoch; clamped to 1 so short epochs still get
    # a positive interval.
    log_every = max(1, steps_per_epoch // 50)

    logger.info("Training configuration:")
    for label, value in (
        ("Steps per epoch", steps_per_epoch),
        ("Total steps", total),
        ("Max steps", capped),
        ("Logging steps", log_every),
    ):
        logger.info(f"- {label}: {value}")

    return {"total_steps": total, "max_steps": capped, "logging_steps": log_every}
 | |
| 
 | |
| 
 | |
| def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
 | |
|     """Get save and evaluation strategy based on output directory.
 | |
|     Args:
 | |
|         output_dir_path: Optional path to save the model
 | |
|     Returns:
 | |
|         tuple: (save_strategy, eval_strategy)
 | |
|     """
 | |
|     if output_dir_path:
 | |
|         logger.info(f"Will save checkpoints to {output_dir_path}")
 | |
|         return "epoch", "epoch"
 | |
|     return "no", "no"
 | |
| 
 | |
| 
 | |
def create_checkpoints(
    output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
    """Create checkpoint objects from training output.

    Scans ``output_dir_path`` for ``checkpoint-<step>`` directories, orders them by
    their numeric step and numbers them 1, 2, ... as epochs. If
    ``output_dir_path / final_model_name`` exists, a checkpoint for it is appended
    last, using ``config.n_epochs`` as its epoch.

    Args:
        output_dir_path: Path to the training output directory
        job_uuid: Unique identifier for the training job
        model: Model identifier
        config: Training configuration
        final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)

    Returns:
        List of Checkpoint objects: per-step checkpoints in ascending step order,
        followed by the final model checkpoint if that directory exists.
    """
    # Collect (step, directory) pairs. The glob also matches names such as
    # "checkpoint-best" or "checkpoint-"; those have no integer step, so skip them
    # (with a warning) instead of letting int() raise and abort the whole listing.
    numbered_dirs: list[tuple[int, Path]] = []
    for candidate in output_dir_path.glob("checkpoint-*"):
        if not candidate.is_dir():
            continue
        step_token = candidate.name.split("-")[1]
        if not step_token.isdecimal():
            logger.warning(f"Skipping {candidate}: directory name has no numeric step suffix")
            continue
        numbered_dirs.append((int(step_token), candidate))
    # Stable sort on the step number only, so ordering matches sorting by int(step).
    numbered_dirs.sort(key=lambda pair: pair[0])

    # Epoch numbers come from position in the sorted list. This assumes one
    # checkpoint directory per epoch, matching the "epoch" save strategy that
    # get_save_strategy selects.
    checkpoints = [
        Checkpoint(
            identifier=checkpoint_dir.name,
            created_at=datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC),
            epoch=epoch_number,
            post_training_job_id=job_uuid,
            path=str(checkpoint_dir),
        )
        for epoch_number, (_, checkpoint_dir) in enumerate(numbered_dirs, start=1)
    ]

    # Add final model
    final_model_path = output_dir_path / final_model_name
    if final_model_path.exists():
        # Any name other than "merged_model" is labelled "dpo".
        training_type = "sft" if final_model_name == "merged_model" else "dpo"
        checkpoints.append(
            Checkpoint(
                identifier=f"{model}-{training_type}-{config.n_epochs}",
                created_at=datetime.now(UTC),
                epoch=config.n_epochs,
                post_training_job_id=job_uuid,
                path=str(final_model_path),
            )
        )

    return checkpoints
 |