fix the changes requested in review

Nehanth 2025-07-29 17:22:51 +00:00
parent 41a45580e0
commit 518bf2fc34
6 changed files with 42 additions and 67 deletions

View file

@@ -27,6 +27,7 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | ./checkpoints/dpo | |
## Sample Configuration
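
The sample configuration itself is not part of this hunk. Purely as an illustration of the fields in the table above, a config carrying the new `dpo_output_dir` might be constructed as follows; the import path is an assumption based on the relative imports in the recipe files, and all other fields are left at their defaults:

```python
# Illustrative only: field names come from the table above, values from their defaults.
# The import path is an assumption, not confirmed by this diff.
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

config = HuggingFacePostTrainingConfig(
    dpo_beta=0.1,
    use_reference_model=True,
    dpo_loss_type="sigmoid",
    dpo_output_dir="./checkpoints/dpo",  # new field added in this commit
)
```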

View file

@@ -71,6 +71,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
dpo_output_dir: str = "./checkpoints/dpo"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:

View file

@@ -132,12 +132,9 @@ class HuggingFacePostTrainingImpl:
datasets_api=self.datasets_api,
)
# Use default checkpoint directory
output_dir = f"./checkpoints/dpo/{job_uuid}"
resources_allocated, checkpoints = await recipe.train(
model=finetuned_model,
output_dir=output_dir,
output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
job_uuid=job_uuid,
dpo_config=algorithm_config,
config=training_config,

View file

@@ -8,20 +8,9 @@ import gc
import json
import logging
import multiprocessing
import os
from pathlib import Path
from typing import Any
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from peft import LoraConfig
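
The tokenizer-parallelism and MKL environment variables removed above are now applied through the shared `setup_environment()` helper imported from `..utils` (its definition is partially visible in the utils hunk near the end of this diff). A minimal sketch of what that helper presumably contains, reconstructed from the removed lines:

```python
import os


def setup_environment():
    """Sketch of the shared helper, reconstructed from the lines removed above."""
    # Set tokenizer parallelism environment variable
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Force PyTorch to use OpenBLAS instead of MKL
    os.environ["MKL_THREADING_LAYER"] = "GNU"
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
    os.environ["MKL_NUM_THREADS"] = "1"
```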
@@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
@@ -47,8 +37,8 @@ from ..utils import (
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
@@ -239,7 +229,7 @@ class HFFinetuningSingleDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
@@ -383,6 +373,9 @@ class HFFinetuningSingleDevice:
) -> None:
"""Run the training process with signal handling."""
# Setup environment variables
setup_environment()
# Setup signal handlers
setup_signal_handlers()
@@ -489,7 +482,8 @@ class HFFinetuningSingleDevice:
logger.info("Starting training in separate process")
try:
# Setup multiprocessing for device
setup_multiprocessing_for_device(device)
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,

View file

@@ -7,20 +7,9 @@
import gc
import logging
import multiprocessing
import os
from pathlib import Path
from typing import Any
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from transformers import (
@@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
@@ -43,8 +33,8 @@ from ..utils import (
get_memory_stats,
get_save_strategy,
load_model,
setup_data,
setup_multiprocessing_for_device,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
@@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
def validate_dataset_format(self, rows: list[dict]) -> None:
"""Validate that the dataset has the required fields for DPO training."""
required_fields = ["prompt", "chosen", "rejected"]
if not rows:
logger.warning("Dataset is empty")
return False
raise ValueError("Dataset is empty")
for i, row in enumerate(rows):
if not isinstance(row, dict):
logger.warning(f"Row {i} is not a dictionary")
return False
raise ValueError(f"Row {i} is not a dictionary")
for field in required_fields:
if field not in row:
logger.warning(f"Row {i} missing required DPO field: {field}")
return False
raise ValueError(f"Row {i} missing required DPO field: {field}")
# Handle both string and list formats
if field == "prompt":
# Prompt should be a string
if not isinstance(row[field], str):
logger.warning(f"Row {i} field '{field}' is not a string")
return False
raise ValueError(f"Row {i} field '{field}' is not a string")
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
return False
raise ValueError(f"Row {i} field '{field}' is empty")
else:
# chosen/rejected can be either strings or lists of messages
if isinstance(row[field], str):
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
return False
raise ValueError(f"Row {i} field '{field}' is empty")
elif isinstance(row[field], list):
if not row[field]:
logger.warning(f"Row {i} field '{field}' is empty list")
return False
raise ValueError(f"Row {i} field '{field}' is empty list")
else:
logger.warning(f"Row {i} field '{field}' is neither string nor list")
return False
raise ValueError(f"Row {i} field '{field}' is neither string nor list")
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
return True
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
"""Process a row in DPO format, handling both string and conversation list formats."""
@@ -220,9 +209,8 @@ class HFDPOAlignmentSingleDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
self.validate_dataset_format(rows)
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
@@ -348,6 +336,9 @@ class HFDPOAlignmentSingleDevice:
) -> None:
"""Run the DPO training process with signal handling."""
# Setup environment variables
setup_environment()
# Setup signal handlers
setup_signal_handlers()
@@ -457,7 +448,8 @@ class HFDPOAlignmentSingleDevice:
logger.info("Starting DPO training in separate process")
try:
# Setup multiprocessing for device
setup_multiprocessing_for_device(device)
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
import multiprocessing
import os
import signal
import sys
@@ -34,7 +33,7 @@ def setup_environment():
os.environ["MKL_NUM_THREADS"] = "1"
def get_gb(to_convert: int) -> str:
def bytes_to_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
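
The body of the renamed helper is not included in this hunk; based on its docstring, a minimal sketch (the exact formatting is an assumption) could be:

```python
def bytes_to_gb(to_convert: int) -> str:
    """Convert a byte count to gigabytes, formatted to 2 decimal places (sketch)."""
    return f"{to_convert / (1024**3):.2f}"
```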
@@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"total": bytes_to_gb(psutil.virtual_memory().total),
"available": bytes_to_gb(psutil.virtual_memory().available),
"used": bytes_to_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
"allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
"reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
"max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
"system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_rss": bytes_to_gb(process.memory_info().rss),
"process_vms": bytes_to_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
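
For reference, a hedged usage example of the helper shown above, assuming it is imported from the provider's utils module:

```python
import torch

# Hypothetical usage; get_memory_stats is the function shown in the hunk above.
stats = get_memory_stats(torch.device("cpu"))
print(stats["system_memory"]["used"], "GB of system memory in use")
print(stats["device_memory"]["process_rss"], "GB resident in this process")
```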
@@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
return device
async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await datasetio_api.iterrows(
@@ -268,12 +267,3 @@ def create_checkpoints(
checkpoints.append(checkpoint)
return checkpoints
def setup_multiprocessing_for_device(device: torch.device):
"""Setup multiprocessing start method based on device type.
Args:
device: The device being used for training
"""
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)