fix the changes requested in review

Nehanth 2025-07-29 17:22:51 +00:00
parent 41a45580e0
commit 518bf2fc34
6 changed files with 42 additions and 67 deletions

View file

@@ -27,6 +27,7 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
 | `dpo_beta` | `<class 'float'>` | No | 0.1 | |
 | `use_reference_model` | `<class 'bool'>` | No | True | |
 | `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
+| `dpo_output_dir` | `<class 'str'>` | No | ./checkpoints/dpo | |
 ## Sample Configuration

View file

@@ -71,6 +71,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
     dpo_beta: float = 0.1
     use_reference_model: bool = True
     dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+    dpo_output_dir: str = "./checkpoints/dpo"
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
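The new field is a plain string with a relative-path default. A minimal standalone sketch of the behavior, using a stand-in Pydantic model rather than the provider's actual class:

```python
from pydantic import BaseModel


class DPOSettings(BaseModel):
    # Stand-in mirroring the new provider field and its default.
    dpo_output_dir: str = "./checkpoints/dpo"


print(DPOSettings().dpo_output_dir)                                  # ./checkpoints/dpo
print(DPOSettings(dpo_output_dir="/mnt/ckpts/dpo").dpo_output_dir)   # /mnt/ckpts/dpo
```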

View file

@@ -132,12 +132,9 @@ class HuggingFacePostTrainingImpl:
             datasets_api=self.datasets_api,
         )
 
-        # Use default checkpoint directory
-        output_dir = f"./checkpoints/dpo/{job_uuid}"
-
         resources_allocated, checkpoints = await recipe.train(
             model=finetuned_model,
-            output_dir=output_dir,
+            output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
             job_uuid=job_uuid,
             dpo_config=algorithm_config,
             config=training_config,
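With the hardcoded `./checkpoints/dpo/{job_uuid}` removed, the per-job directory is now composed from the configurable base path. A sketch of the composition (the job id below is hypothetical):

```python
dpo_output_dir = "./checkpoints/dpo"  # default from HuggingFacePostTrainingConfig
job_uuid = "a1b2c3d4"  # hypothetical job id

# Same shape the provider builds when invoking recipe.train(...).
output_dir = f"{dpo_output_dir}/{job_uuid}"
print(output_dir)  # ./checkpoints/dpo/a1b2c3d4
```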

View file

@@ -8,20 +8,9 @@
 import gc
 import json
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
 
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
-
 import torch
 from datasets import Dataset
 from peft import LoraConfig
@@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -47,8 +37,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
@@ -239,7 +229,7 @@
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
         logger.info(f"Loaded {len(rows)} rows from dataset")
@@ -383,6 +373,9 @@
     ) -> None:
         """Run the training process with signal handling."""
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
@@ -489,7 +482,8 @@
         logger.info("Starting training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
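Both recipes now inline the start-method selection that `setup_multiprocessing_for_device` used to wrap. A runnable sketch of the same pattern with a hypothetical worker; the key point is that CUDA/MPS state does not survive `fork`, so the child must be spawned with a fresh interpreter:

```python
import multiprocessing


def _train_worker(job_uuid: str) -> None:
    # Hypothetical stand-in for the recipe's _run_training_sync.
    print(f"training {job_uuid} in an isolated process")


def launch(device_type: str, job_uuid: str) -> None:
    # CUDA/MPS contexts cannot be shared across fork(); "spawn" starts the
    # child with a fresh interpreter so the accelerator initializes cleanly.
    if device_type in ["cuda", "mps"]:
        multiprocessing.set_start_method("spawn", force=True)
    process = multiprocessing.Process(target=_train_worker, args=(job_uuid,))
    process.start()
    process.join()


if __name__ == "__main__":
    launch("cuda", "job-1234")
```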

View file

@@ -7,20 +7,9 @@
 import gc
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
 
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
-
 import torch
 from datasets import Dataset
 from transformers import (
@@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -43,8 +33,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
@@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
         self.datasets_api = datasets_api
         self.job_uuid = job_uuid
 
-    def validate_dataset_format(self, rows: list[dict]) -> bool:
+    def validate_dataset_format(self, rows: list[dict]) -> None:
         """Validate that the dataset has the required fields for DPO training."""
         required_fields = ["prompt", "chosen", "rejected"]
 
         if not rows:
             logger.warning("Dataset is empty")
-            return False
+            raise ValueError("Dataset is empty")
 
         for i, row in enumerate(rows):
             if not isinstance(row, dict):
                 logger.warning(f"Row {i} is not a dictionary")
-                return False
+                raise ValueError(f"Row {i} is not a dictionary")
 
             for field in required_fields:
                 if field not in row:
                     logger.warning(f"Row {i} missing required DPO field: {field}")
-                    return False
+                    raise ValueError(f"Row {i} missing required DPO field: {field}")
 
                 # Handle both string and list formats
                 if field == "prompt":
                     # Prompt should be a string
                     if not isinstance(row[field], str):
                         logger.warning(f"Row {i} field '{field}' is not a string")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is not a string")
                     if not row[field].strip():
                         logger.warning(f"Row {i} field '{field}' is empty")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is empty")
                 else:
                     # chosen/rejected can be either strings or lists of messages
                     if isinstance(row[field], str):
                         if not row[field].strip():
                             logger.warning(f"Row {i} field '{field}' is empty")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty")
                     elif isinstance(row[field], list):
                         if not row[field]:
                             logger.warning(f"Row {i} field '{field}' is empty list")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty list")
                     else:
                         logger.warning(f"Row {i} field '{field}' is neither string nor list")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is neither string nor list")
 
         logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
-        return True
 
     def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
         """Process a row in DPO format, handling both string and conversation list formats."""
@@ -220,9 +209,8 @@
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
-            raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
+        self.validate_dataset_format(rows)
         logger.info(f"Loaded {len(rows)} rows from dataset")
 
         # Initialize tokenizer
@@ -348,6 +336,9 @@
     ) -> None:
         """Run the DPO training process with signal handling."""
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
@@ -457,7 +448,8 @@
         logger.info("Starting DPO training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
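Note the validation change earlier in this file: `validate_dataset_format` now raises instead of returning a bool, so the caller drops its `if not ...: raise` wrapper and gets a row-specific error message for free. A condensed, self-contained sketch of the new contract (the real method also checks string/list shapes per field):

```python
def validate_dataset_format(rows: list[dict]) -> None:
    # Raises ValueError with a row-specific message instead of returning False.
    required_fields = ["prompt", "chosen", "rejected"]
    if not rows:
        raise ValueError("Dataset is empty")
    for i, row in enumerate(rows):
        missing = [f for f in required_fields if f not in row]
        if missing:
            raise ValueError(f"Row {i} missing required DPO fields: {missing}")


try:
    validate_dataset_format([{"prompt": "2+2?", "chosen": "4"}])
except ValueError as err:
    print(f"rejected dataset: {err}")  # Row 0 missing required DPO fields: ['rejected']
```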

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import logging
-import multiprocessing
 import os
 import signal
 import sys
@@ -34,7 +33,7 @@ def setup_environment():
     os.environ["MKL_NUM_THREADS"] = "1"
 
 
-def get_gb(to_convert: int) -> str:
+def bytes_to_gb(to_convert: int) -> str:
     """Converts memory stats to GB and formats to 2 decimal places.
 
     Args:
         to_convert: Memory value in bytes
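Per the docstring, the renamed helper formats a byte count as gigabytes with two decimals. A one-line sketch assuming the conventional binary divisor (2**30); the actual implementation may differ:

```python
def bytes_to_gb(to_convert: int) -> str:
    # Assumption: binary gigabytes (2**30 bytes), formatted to 2 decimals.
    return f"{to_convert / (1024**3):.2f}"


print(bytes_to_gb(8_589_934_592))  # 8.00
```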
@@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
     """Get memory statistics for the given device."""
     stats = {
         "system_memory": {
-            "total": get_gb(psutil.virtual_memory().total),
-            "available": get_gb(psutil.virtual_memory().available),
-            "used": get_gb(psutil.virtual_memory().used),
+            "total": bytes_to_gb(psutil.virtual_memory().total),
+            "available": bytes_to_gb(psutil.virtual_memory().available),
+            "used": bytes_to_gb(psutil.virtual_memory().used),
             "percent": psutil.virtual_memory().percent,
         }
     }
 
     if device.type == "cuda":
         stats["device_memory"] = {
-            "allocated": get_gb(torch.cuda.memory_allocated(device)),
-            "reserved": get_gb(torch.cuda.memory_reserved(device)),
-            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
+            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
         }
     elif device.type == "mps":
         # MPS doesn't provide direct memory stats, but we can track system memory
         stats["device_memory"] = {
             "note": "MPS memory stats not directly available",
-            "system_memory_used": get_gb(psutil.virtual_memory().used),
+            "system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
         }
     elif device.type == "cpu":
         # For CPU, we track process memory usage
         process = psutil.Process()
         stats["device_memory"] = {
-            "process_rss": get_gb(process.memory_info().rss),
-            "process_vms": get_gb(process.memory_info().vms),
+            "process_rss": bytes_to_gb(process.memory_info().rss),
+            "process_vms": bytes_to_gb(process.memory_info().vms),
             "process_percent": process.memory_percent(),
         }
@@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
     return device
 
 
-async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
         all_rows = await datasetio_api.iterrows(
@@ -268,12 +267,3 @@ def create_checkpoints(
         checkpoints.append(checkpoint)
 
     return checkpoints
-
-
-def setup_multiprocessing_for_device(device: torch.device):
-    """Setup multiprocessing start method based on device type.
-
-    Args:
-        device: The device being used for training
-    """
-    if device.type in ["cuda", "mps"]:
-        multiprocessing.set_start_method("spawn", force=True)
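For reference, the `setup_environment()` helper that both recipes now call at process start consolidates the module-level lines removed above. A sketch assembled from those removed lines (the helper's exact body may differ):

```python
import os


def setup_environment() -> None:
    # Values taken verbatim from the env-var lines this commit removed
    # from the two recipe modules.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence tokenizer fork warnings
    os.environ["MKL_THREADING_LAYER"] = "GNU"       # force OpenBLAS path over MKL
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
    os.environ["MKL_NUM_THREADS"] = "1"
```

Calling this inside the training process (rather than at import time) keeps the environment tweaks scoped to the spawned worker instead of leaking into every importer of the recipe module.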