mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 18:46:16 +00:00
fix the chnages that requested in review
This commit is contained in:
parent
41a45580e0
commit
518bf2fc34
6 changed files with 42 additions and 67 deletions
|
|
@ -27,6 +27,7 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
|
|||
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
|
||||
| `use_reference_model` | `<class 'bool'>` | No | True | |
|
||||
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
|
||||
| `dpo_output_dir` | `<class 'str'>` | No | ./checkpoints/dpo | |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
|
|||
dpo_beta: float = 0.1
|
||||
use_reference_model: bool = True
|
||||
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
|
||||
dpo_output_dir: str = "./checkpoints/dpo"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -132,12 +132,9 @@ class HuggingFacePostTrainingImpl:
|
|||
datasets_api=self.datasets_api,
|
||||
)
|
||||
|
||||
# Use default checkpoint directory
|
||||
output_dir = f"./checkpoints/dpo/{job_uuid}"
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train(
|
||||
model=finetuned_model,
|
||||
output_dir=output_dir,
|
||||
output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
|
||||
job_uuid=job_uuid,
|
||||
dpo_config=algorithm_config,
|
||||
config=training_config,
|
||||
|
|
|
|||
|
|
@ -8,20 +8,9 @@ import gc
|
|||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
# Set tokenizer parallelism environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Force PyTorch to use OpenBLAS instead of MKL
|
||||
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft import LoraConfig
|
||||
|
|
@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
|
|
@ -47,8 +37,8 @@ from ..utils import (
|
|||
get_memory_stats,
|
||||
get_save_strategy,
|
||||
load_model,
|
||||
setup_data,
|
||||
setup_multiprocessing_for_device,
|
||||
load_rows_from_dataset,
|
||||
setup_environment,
|
||||
setup_signal_handlers,
|
||||
setup_torch_device,
|
||||
split_dataset,
|
||||
|
|
@ -239,7 +229,7 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
|
@ -383,6 +373,9 @@ class HFFinetuningSingleDevice:
|
|||
) -> None:
|
||||
"""Run the training process with signal handling."""
|
||||
|
||||
# Setup environment variables
|
||||
setup_environment()
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers()
|
||||
|
||||
|
|
@ -489,7 +482,8 @@ class HFFinetuningSingleDevice:
|
|||
logger.info("Starting training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
setup_multiprocessing_for_device(device)
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=self._run_training_sync,
|
||||
|
|
|
|||
|
|
@ -7,20 +7,9 @@
|
|||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
# Set tokenizer parallelism environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Force PyTorch to use OpenBLAS instead of MKL
|
||||
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
|
|
@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
|
|||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
|
|
@ -43,8 +33,8 @@ from ..utils import (
|
|||
get_memory_stats,
|
||||
get_save_strategy,
|
||||
load_model,
|
||||
setup_data,
|
||||
setup_multiprocessing_for_device,
|
||||
load_rows_from_dataset,
|
||||
setup_environment,
|
||||
setup_signal_handlers,
|
||||
setup_torch_device,
|
||||
split_dataset,
|
||||
|
|
@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
|
|||
self.datasets_api = datasets_api
|
||||
self.job_uuid = job_uuid
|
||||
|
||||
def validate_dataset_format(self, rows: list[dict]) -> bool:
|
||||
def validate_dataset_format(self, rows: list[dict]) -> None:
|
||||
"""Validate that the dataset has the required fields for DPO training."""
|
||||
required_fields = ["prompt", "chosen", "rejected"]
|
||||
|
||||
if not rows:
|
||||
logger.warning("Dataset is empty")
|
||||
return False
|
||||
raise ValueError("Dataset is empty")
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
if not isinstance(row, dict):
|
||||
logger.warning(f"Row {i} is not a dictionary")
|
||||
return False
|
||||
raise ValueError(f"Row {i} is not a dictionary")
|
||||
|
||||
for field in required_fields:
|
||||
if field not in row:
|
||||
logger.warning(f"Row {i} missing required DPO field: {field}")
|
||||
return False
|
||||
raise ValueError(f"Row {i} missing required DPO field: {field}")
|
||||
|
||||
# Handle both string and list formats
|
||||
if field == "prompt":
|
||||
# Prompt should be a string
|
||||
if not isinstance(row[field], str):
|
||||
logger.warning(f"Row {i} field '{field}' is not a string")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is not a string")
|
||||
if not row[field].strip():
|
||||
logger.warning(f"Row {i} field '{field}' is empty")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is empty")
|
||||
else:
|
||||
# chosen/rejected can be either strings or lists of messages
|
||||
if isinstance(row[field], str):
|
||||
if not row[field].strip():
|
||||
logger.warning(f"Row {i} field '{field}' is empty")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is empty")
|
||||
elif isinstance(row[field], list):
|
||||
if not row[field]:
|
||||
logger.warning(f"Row {i} field '{field}' is empty list")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is empty list")
|
||||
else:
|
||||
logger.warning(f"Row {i} field '{field}' is neither string nor list")
|
||||
return False
|
||||
raise ValueError(f"Row {i} field '{field}' is neither string nor list")
|
||||
|
||||
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
|
||||
return True
|
||||
|
||||
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
|
||||
"""Process a row in DPO format, handling both string and conversation list formats."""
|
||||
|
|
@ -220,9 +209,8 @@ class HFDPOAlignmentSingleDevice:
|
|||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
self.validate_dataset_format(rows)
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
|
|
@ -348,6 +336,9 @@ class HFDPOAlignmentSingleDevice:
|
|||
) -> None:
|
||||
"""Run the DPO training process with signal handling."""
|
||||
|
||||
# Setup environment variables
|
||||
setup_environment()
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers()
|
||||
|
||||
|
|
@ -457,7 +448,8 @@ class HFDPOAlignmentSingleDevice:
|
|||
logger.info("Starting DPO training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
setup_multiprocessing_for_device(device)
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=self._run_training_sync,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
|
@ -34,7 +33,7 @@ def setup_environment():
|
|||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
|
||||
|
||||
def get_gb(to_convert: int) -> str:
|
||||
def bytes_to_gb(to_convert: int) -> str:
|
||||
"""Converts memory stats to GB and formats to 2 decimal places.
|
||||
Args:
|
||||
to_convert: Memory value in bytes
|
||||
|
|
@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
|
|||
"""Get memory statistics for the given device."""
|
||||
stats = {
|
||||
"system_memory": {
|
||||
"total": get_gb(psutil.virtual_memory().total),
|
||||
"available": get_gb(psutil.virtual_memory().available),
|
||||
"used": get_gb(psutil.virtual_memory().used),
|
||||
"total": bytes_to_gb(psutil.virtual_memory().total),
|
||||
"available": bytes_to_gb(psutil.virtual_memory().available),
|
||||
"used": bytes_to_gb(psutil.virtual_memory().used),
|
||||
"percent": psutil.virtual_memory().percent,
|
||||
}
|
||||
}
|
||||
|
||||
if device.type == "cuda":
|
||||
stats["device_memory"] = {
|
||||
"allocated": get_gb(torch.cuda.memory_allocated(device)),
|
||||
"reserved": get_gb(torch.cuda.memory_reserved(device)),
|
||||
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
|
||||
"allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
|
||||
"reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
|
||||
"max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
|
||||
}
|
||||
elif device.type == "mps":
|
||||
# MPS doesn't provide direct memory stats, but we can track system memory
|
||||
stats["device_memory"] = {
|
||||
"note": "MPS memory stats not directly available",
|
||||
"system_memory_used": get_gb(psutil.virtual_memory().used),
|
||||
"system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
|
||||
}
|
||||
elif device.type == "cpu":
|
||||
# For CPU, we track process memory usage
|
||||
process = psutil.Process()
|
||||
stats["device_memory"] = {
|
||||
"process_rss": get_gb(process.memory_info().rss),
|
||||
"process_vms": get_gb(process.memory_info().vms),
|
||||
"process_rss": bytes_to_gb(process.memory_info().rss),
|
||||
"process_vms": bytes_to_gb(process.memory_info().vms),
|
||||
"process_percent": process.memory_percent(),
|
||||
}
|
||||
|
||||
|
|
@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
|
|||
return device
|
||||
|
||||
|
||||
async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
|
||||
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Load dataset from llama stack dataset provider"""
|
||||
try:
|
||||
all_rows = await datasetio_api.iterrows(
|
||||
|
|
@ -268,12 +267,3 @@ def create_checkpoints(
|
|||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
def setup_multiprocessing_for_device(device: torch.device):
|
||||
"""Setup multiprocessing start method based on device type.
|
||||
Args:
|
||||
device: The device being used for training
|
||||
"""
|
||||
if device.type in ["cuda", "mps"]:
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue