fix the changes requested in review

Nehanth 2025-07-29 17:22:51 +00:00
parent 41a45580e0
commit 518bf2fc34
6 changed files with 42 additions and 67 deletions

View file

@@ -8,20 +8,9 @@ import gc
 import json
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
 import torch
 from datasets import Dataset
 from peft import LoraConfig
@@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@ -47,8 +37,8 @@ from ..utils import (
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
@@ -239,7 +229,7 @@ class HFFinetuningSingleDevice:
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
         logger.info(f"Loaded {len(rows)} rows from dataset")
@@ -383,6 +373,9 @@ class HFFinetuningSingleDevice:
     ) -> None:
         """Run the training process with signal handling."""
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
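The module-level `os.environ` block deleted above now sits behind the new `setup_environment()` helper in `../utils`, whose body is not part of this diff. A minimal sketch of what it presumably contains, reconstructed from the deleted lines:

import os


def setup_environment() -> None:
    """Shared environment setup for the finetuning and DPO recipes (sketch)."""
    # Disable tokenizer parallelism to avoid fork-related warnings
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Force PyTorch to use OpenBLAS instead of MKL
    os.environ["MKL_THREADING_LAYER"] = "GNU"
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
    os.environ["MKL_NUM_THREADS"] = "1"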
@@ -489,7 +482,8 @@ class HFFinetuningSingleDevice:
         logger.info("Starting training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
             process = multiprocessing.Process(
                 target=self._run_training_sync,
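The `setup_multiprocessing_for_device` wrapper is inlined at its only call sites. The reason for "spawn" is that a forked child inherits the parent's already-initialized CUDA/MPS context, which is unsafe; "spawn" starts a fresh interpreter instead. A standalone sketch of the pattern (hypothetical `_worker`, outside llama-stack):

import multiprocessing

import torch


def _worker() -> None:
    # Under "spawn" the child re-imports this module rather than inheriting
    # the parent's accelerator state
    print("cuda available in child:", torch.cuda.is_available())


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type in ["cuda", "mps"]:
        # force=True overrides any start method set earlier in the process
        multiprocessing.set_start_method("spawn", force=True)
    p = multiprocessing.Process(target=_worker)
    p.start()
    p.join()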

View file

@@ -7,20 +7,9 @@
 import gc
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
 import torch
 from datasets import Dataset
 from transformers import (
@@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -43,8 +33,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
@@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
         self.datasets_api = datasets_api
         self.job_uuid = job_uuid

-    def validate_dataset_format(self, rows: list[dict]) -> bool:
+    def validate_dataset_format(self, rows: list[dict]) -> None:
         """Validate that the dataset has the required fields for DPO training."""
         required_fields = ["prompt", "chosen", "rejected"]
         if not rows:
-            logger.warning("Dataset is empty")
-            return False
+            raise ValueError("Dataset is empty")
         for i, row in enumerate(rows):
             if not isinstance(row, dict):
-                logger.warning(f"Row {i} is not a dictionary")
-                return False
+                raise ValueError(f"Row {i} is not a dictionary")
             for field in required_fields:
                 if field not in row:
-                    logger.warning(f"Row {i} missing required DPO field: {field}")
-                    return False
+                    raise ValueError(f"Row {i} missing required DPO field: {field}")
                 # Handle both string and list formats
                 if field == "prompt":
                     # Prompt should be a string
                     if not isinstance(row[field], str):
-                        logger.warning(f"Row {i} field '{field}' is not a string")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is not a string")
                     if not row[field].strip():
-                        logger.warning(f"Row {i} field '{field}' is empty")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is empty")
                 else:
                     # chosen/rejected can be either strings or lists of messages
                     if isinstance(row[field], str):
                         if not row[field].strip():
-                            logger.warning(f"Row {i} field '{field}' is empty")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty")
                     elif isinstance(row[field], list):
                         if not row[field]:
-                            logger.warning(f"Row {i} field '{field}' is empty list")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty list")
                     else:
-                        logger.warning(f"Row {i} field '{field}' is neither string nor list")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is neither string nor list")
         logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
-        return True

     def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
         """Process a row in DPO format, handling both string and conversation list formats."""
@@ -220,9 +209,8 @@ class HFDPOAlignmentSingleDevice:
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
-            raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
+        self.validate_dataset_format(rows)
         logger.info(f"Loaded {len(rows)} rows from dataset")

         # Initialize tokenizer
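Net effect of the two validation hunks: callers no longer branch on a boolean, and the `ValueError` pinpoints the offending row and field instead of one generic message. An illustrative example (hypothetical rows, assuming `recipe` is an `HFDPOAlignmentSingleDevice` instance):

bad_rows = [{"prompt": "What is 2+2?", "chosen": "4"}]  # "rejected" field missing
try:
    recipe.validate_dataset_format(bad_rows)
except ValueError as e:
    print(e)  # Row 0 missing required DPO field: rejected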
@@ -348,6 +336,9 @@ class HFDPOAlignmentSingleDevice:
     ) -> None:
         """Run the DPO training process with signal handling."""
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
@@ -457,7 +448,8 @@ class HFDPOAlignmentSingleDevice:
         logger.info("Starting DPO training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
             process = multiprocessing.Process(
                 target=self._run_training_sync,