mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-26 22:58:05 +00:00
fix the chnages that requested in review
This commit is contained in:
parent
41a45580e0
commit
518bf2fc34
6 changed files with 42 additions and 67 deletions
|
|
@ -8,20 +8,9 @@ import gc
|
|||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
# Set tokenizer parallelism environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Force PyTorch to use OpenBLAS instead of MKL
|
||||
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft import LoraConfig
|
||||
|
|
@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
|
|
@ -47,8 +37,8 @@ from ..utils import (
|
|||
get_memory_stats,
|
||||
get_save_strategy,
|
||||
load_model,
|
||||
setup_data,
|
||||
setup_multiprocessing_for_device,
|
||||
load_rows_from_dataset,
|
||||
setup_environment,
|
||||
setup_signal_handlers,
|
||||
setup_torch_device,
|
||||
split_dataset,
|
||||
|
|
@ -239,7 +229,7 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
|
@ -383,6 +373,9 @@ class HFFinetuningSingleDevice:
|
|||
) -> None:
|
||||
"""Run the training process with signal handling."""
|
||||
|
||||
# Setup environment variables
|
||||
setup_environment()
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers()
|
||||
|
||||
|
|
@ -489,7 +482,8 @@ class HFFinetuningSingleDevice:
|
|||
logger.info("Starting training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
setup_multiprocessing_for_device(device)
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=self._run_training_sync,
|
||||
|
|
|
|||
|
|
@ -7,20 +7,9 @@
|
|||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
# Set tokenizer parallelism environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Force PyTorch to use OpenBLAS instead of MKL
|
||||
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
|
|
@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
|
|||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
|
|
@ -43,8 +33,8 @@ from ..utils import (
|
|||
get_memory_stats,
|
||||
get_save_strategy,
|
||||
load_model,
|
||||
setup_data,
|
||||
setup_multiprocessing_for_device,
|
||||
load_rows_from_dataset,
|
||||
setup_environment,
|
||||
setup_signal_handlers,
|
||||
setup_torch_device,
|
||||
split_dataset,
|
||||
|
|
@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
|
|||
self.datasets_api = datasets_api
|
||||
self.job_uuid = job_uuid
|
||||
|
||||
def validate_dataset_format(self, rows: list[dict]) -> bool:
|
||||
def validate_dataset_format(self, rows: list[dict]) -> None:
|
||||
"""Validate that the dataset has the required fields for DPO training."""
|
||||
required_fields = ["prompt", "chosen", "rejected"]
|
||||
|
||||
if not rows:
|
||||
logger.warning("Dataset is empty")
|
||||
return False
|
||||
raise ValueError("Dataset is empty")
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
if not isinstance(row, dict):
|
||||
logger.warning(f"Row {i} is not a dictionary")
|
||||
return False
|
||||
raise ValueError(f"Row {i} is not a dictionary")
|
||||
|
||||
for field in required_fields:
|
||||
if field not in row:
|
||||
logger.warning(f"Row {i} missing required DPO field: {field}")
|
||||
return False
|
||||
raise ValueError(f"Row {i} missing required DPO field: {field}")
|
||||
|
||||
# Handle both string and list formats
|
||||
if field == "prompt":
|
||||
# Prompt should be a string
|
||||
if not isinstance(row[field], str):
|
||||
logger.warning(f"Row {i} field '{field}' is not a string")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is not a string")
|
||||
if not row[field].strip():
|
||||
logger.warning(f"Row {i} field '{field}' is empty")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is empty")
|
||||
else:
|
||||
# chosen/rejected can be either strings or lists of messages
|
||||
if isinstance(row[field], str):
|
||||
if not row[field].strip():
|
||||
logger.warning(f"Row {i} field '{field}' is empty")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is empty")
|
||||
elif isinstance(row[field], list):
|
||||
if not row[field]:
|
||||
logger.warning(f"Row {i} field '{field}' is empty list")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is empty list")
|
||||
else:
|
||||
logger.warning(f"Row {i} field '{field}' is neither string nor list")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is neither string nor list")
|
||||
|
||||
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
|
||||
return True
|
||||
|
||||
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
|
||||
"""Process a row in DPO format, handling both string and conversation list formats."""
|
||||
|
|
@ -220,9 +209,8 @@ class HFDPOAlignmentSingleDevice:
|
|||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
self.validate_dataset_format(rows)
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
|
|
@ -348,6 +336,9 @@ class HFDPOAlignmentSingleDevice:
|
|||
) -> None:
|
||||
"""Run the DPO training process with signal handling."""
|
||||
|
||||
# Setup environment variables
|
||||
setup_environment()
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers()
|
||||
|
||||
|
|
@ -457,7 +448,8 @@ class HFDPOAlignmentSingleDevice:
|
|||
logger.info("Starting DPO training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
setup_multiprocessing_for_device(device)
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=self._run_training_sync,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue