diff --git a/docs/source/providers/post_training/inline_huggingface.md b/docs/source/providers/post_training/inline_huggingface.md
index 53025b233..0a8745e71 100644
--- a/docs/source/providers/post_training/inline_huggingface.md
+++ b/docs/source/providers/post_training/inline_huggingface.md
@@ -27,6 +27,7 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
 | `dpo_beta` | `` | No | 0.1 | |
 | `use_reference_model` | `` | No | True | |
 | `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
+| `dpo_output_dir` | `` | No | ./checkpoints/dpo | |
 
 ## Sample Configuration
 
diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py
index 74733be09..dae8fcc04 100644
--- a/llama_stack/providers/inline/post_training/huggingface/config.py
+++ b/llama_stack/providers/inline/post_training/huggingface/config.py
@@ -71,6 +71,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
     dpo_beta: float = 0.1
     use_reference_model: bool = True
     dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+    dpo_output_dir: str = "./checkpoints/dpo"
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py b/llama_stack/providers/inline/post_training/huggingface/post_training.py
index 0160c6267..81622e2b7 100644
--- a/llama_stack/providers/inline/post_training/huggingface/post_training.py
+++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py
@@ -132,12 +132,9 @@ class HuggingFacePostTrainingImpl:
             datasets_api=self.datasets_api,
         )
 
-        # Use default checkpoint directory
-        output_dir = f"./checkpoints/dpo/{job_uuid}"
-
         resources_allocated, checkpoints = await recipe.train(
             model=finetuned_model,
-            output_dir=output_dir,
+            output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
             job_uuid=job_uuid,
             dpo_config=algorithm_config,
             config=training_config,
diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
index 935e5ff15..6853ee11a 100644
--- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
@@ -8,20 +8,9 @@ import gc
 import json
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
 
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
-
 import torch
 from datasets import Dataset
 from peft import LoraConfig
@@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -47,8 +37,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
@@ -239,7 +229,7 @@ class HFFinetuningSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
         logger.info(f"Loaded {len(rows)} rows from dataset")
@@ -383,6 +373,9 @@ class HFFinetuningSingleDevice:
     ) -> None:
         """Run the training process with signal handling."""
 
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
 
@@ -489,7 +482,8 @@ class HFFinetuningSingleDevice:
         logger.info("Starting training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
index 25d36afd7..a7c19faac 100644
--- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
@@ -7,20 +7,9 @@ import gc
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
 
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
-
 import torch
 from datasets import Dataset
 from transformers import (
@@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -43,8 +33,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
@@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
         self.datasets_api = datasets_api
         self.job_uuid = job_uuid
 
-    def validate_dataset_format(self, rows: list[dict]) -> bool:
+    def validate_dataset_format(self, rows: list[dict]) -> None:
         """Validate that the dataset has the required fields for DPO training."""
         required_fields = ["prompt", "chosen", "rejected"]
 
         if not rows:
             logger.warning("Dataset is empty")
-            return False
+            raise ValueError("Dataset is empty")
 
         for i, row in enumerate(rows):
             if not isinstance(row, dict):
                 logger.warning(f"Row {i} is not a dictionary")
-                return False
+                raise ValueError(f"Row {i} is not a dictionary")
 
             for field in required_fields:
                 if field not in row:
                     logger.warning(f"Row {i} missing required DPO field: {field}")
-                    return False
+                    raise ValueError(f"Row {i} missing required DPO field: {field}")
 
                 # Handle both string and list formats
                 if field == "prompt":
                     # Prompt should be a string
                     if not isinstance(row[field], str):
                         logger.warning(f"Row {i} field '{field}' is not a string")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is not a string")
                     if not row[field].strip():
                         logger.warning(f"Row {i} field '{field}' is empty")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is empty")
                 else:
                     # chosen/rejected can be either strings or lists of messages
                     if isinstance(row[field], str):
                         if not row[field].strip():
                             logger.warning(f"Row {i} field '{field}' is empty")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty")
                     elif isinstance(row[field], list):
                         if not row[field]:
                             logger.warning(f"Row {i} field '{field}' is empty list")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty list")
                     else:
                         logger.warning(f"Row {i} field '{field}' is neither string nor list")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is neither string nor list")
 
         logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
-        return True
 
     def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
         """Process a row in DPO format, handling both string and conversation list formats."""
@@ -220,9 +209,8 @@ class HFDPOAlignmentSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
-            raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
+        self.validate_dataset_format(rows)
         logger.info(f"Loaded {len(rows)} rows from dataset")
 
         # Initialize tokenizer
@@ -348,6 +336,9 @@ class HFDPOAlignmentSingleDevice:
     ) -> None:
         """Run the DPO training process with signal handling."""
 
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
 
@@ -457,7 +448,8 @@ class HFDPOAlignmentSingleDevice:
         logger.info("Starting DPO training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py
index cd75ee06b..3147c19ab 100644
--- a/llama_stack/providers/inline/post_training/huggingface/utils.py
+++ b/llama_stack/providers/inline/post_training/huggingface/utils.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import logging
-import multiprocessing
 import os
 import signal
 import sys
@@ -34,7 +33,7 @@ def setup_environment():
     os.environ["MKL_NUM_THREADS"] = "1"
 
 
-def get_gb(to_convert: int) -> str:
+def bytes_to_gb(to_convert: int) -> str:
     """Converts memory stats to GB and formats to 2 decimal places.
     Args:
         to_convert: Memory value in bytes
@@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
     """Get memory statistics for the given device."""
     stats = {
         "system_memory": {
-            "total": get_gb(psutil.virtual_memory().total),
-            "available": get_gb(psutil.virtual_memory().available),
-            "used": get_gb(psutil.virtual_memory().used),
+            "total": bytes_to_gb(psutil.virtual_memory().total),
+            "available": bytes_to_gb(psutil.virtual_memory().available),
+            "used": bytes_to_gb(psutil.virtual_memory().used),
             "percent": psutil.virtual_memory().percent,
         }
     }
 
     if device.type == "cuda":
         stats["device_memory"] = {
-            "allocated": get_gb(torch.cuda.memory_allocated(device)),
-            "reserved": get_gb(torch.cuda.memory_reserved(device)),
-            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
+            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
         }
     elif device.type == "mps":
         # MPS doesn't provide direct memory stats, but we can track system memory
         stats["device_memory"] = {
             "note": "MPS memory stats not directly available",
-            "system_memory_used": get_gb(psutil.virtual_memory().used),
+            "system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
         }
     elif device.type == "cpu":
         # For CPU, we track process memory usage
         process = psutil.Process()
         stats["device_memory"] = {
-            "process_rss": get_gb(process.memory_info().rss),
-            "process_vms": get_gb(process.memory_info().vms),
+            "process_rss": bytes_to_gb(process.memory_info().rss),
+            "process_vms": bytes_to_gb(process.memory_info().vms),
             "process_percent": process.memory_percent(),
         }
 
@@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
     return device
 
 
-async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
         all_rows = await datasetio_api.iterrows(
@@ -268,12 +267,3 @@ def create_checkpoints(
         checkpoints.append(checkpoint)
 
     return checkpoints
-
-
-def setup_multiprocessing_for_device(device: torch.device):
-    """Setup multiprocessing start method based on device type.
-    Args:
-        device: The device being used for training
-    """
-    if device.type in ["cuda", "mps"]:
-        multiprocessing.set_start_method("spawn", force=True)
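
A few notes on the changes above, with illustrative sketches (none of the code below is part of the patch).

The new `dpo_output_dir` field replaces the hardcoded `./checkpoints/dpo` prefix in `post_training.py`, and its default preserves the old layout. A minimal sketch of how the knob composes with the job UUID; the override path and job id are hypothetical, and this assumes the config's remaining fields all have defaults:

```python
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

# Default reproduces the previous behavior: ./checkpoints/dpo/<job_uuid>
config = HuggingFacePostTrainingConfig()
assert config.dpo_output_dir == "./checkpoints/dpo"

# Point DPO checkpoints at a shared volume instead (hypothetical path)
config = HuggingFacePostTrainingConfig(dpo_output_dir="/mnt/checkpoints/dpo")

# post_training.py appends the job UUID, exactly as before
job_uuid = "job-1234"  # hypothetical job id
output_dir = f"{config.dpo_output_dir}/{job_uuid}"
```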
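
Moving the `os.environ` writes out of module scope means importing either recipe no longer mutates the importing process's environment; both recipes now call `setup_environment()` at the start of the training run instead. Judging from the removed module-level block and the `utils.py` hunk context (which shows `def setup_environment():` ending with the `MKL_NUM_THREADS` line), the helper presumably looks like this sketch:

```python
import os


def setup_environment():
    """Set common environment variables for training (sketch; docstring assumed)."""
    # Set tokenizer parallelism environment variable
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Force PyTorch to use OpenBLAS instead of MKL
    os.environ["MKL_THREADING_LAYER"] = "GNU"
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
    os.environ["MKL_NUM_THREADS"] = "1"
```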
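
Inlining `setup_multiprocessing_for_device` removes a one-line indirection without changing behavior. The device check matters because a forked child on CUDA or MPS inherits an already-initialized accelerator context, which is not fork-safe, while `spawn` launches a fresh interpreter. The pattern in isolation (the `train_worker` function is hypothetical):

```python
import multiprocessing

import torch


def train_worker(job_uuid: str) -> None:  # hypothetical training entry point
    print(f"training {job_uuid}")


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type in ["cuda", "mps"]:
        # fork would inherit an initialized CUDA/Metal context; spawn starts clean
        multiprocessing.set_start_method("spawn", force=True)

    process = multiprocessing.Process(target=train_worker, args=("job-1234",))
    process.start()
    process.join()
```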
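
On the DPO path, `validate_dataset_format` now raises `ValueError` naming the exact failing row and field instead of returning a bool, so the caller's generic "missing required fields" message is replaced by a precise one; the SFT recipe's validator still returns a bool. Per the checks, rows that pass look like the following (values are illustrative, and the validator itself only requires `prompt` to be a non-empty string and `chosen`/`rejected` to be non-empty strings or non-empty lists, so the message dict shape shown is an assumption):

```python
rows = [
    # String format: chosen/rejected are plain completions
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "I believe it is Lyon.",
    },
    # Conversation-list format (message shape illustrative)
    {
        "prompt": "What is the capital of France?",
        "chosen": [{"role": "assistant", "content": "Paris."}],
        "rejected": [{"role": "assistant", "content": "Lyon."}],
    },
]
```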
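
Finally, `setup_data` → `load_rows_from_dataset` and `get_gb` → `bytes_to_gb` are pure renames. Given the latter's docstring and `str` return type, its body is presumably a one-liner along these lines (sketch, not the actual implementation):

```python
def bytes_to_gb(to_convert: int) -> str:
    """Converts memory stats to GB and formats to 2 decimal places (sketch)."""
    return f"{to_convert / (1024**3):.2f}"
```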