Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 18:46:16 +00:00)
fix the changes requested in review
This commit is contained in:
parent 41a45580e0, commit 518bf2fc34
6 changed files with 42 additions and 67 deletions
Provider documentation (configuration table):

```diff
@@ -27,6 +27,7 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin
 | `dpo_beta` | `<class 'float'>` | No | 0.1 | |
 | `use_reference_model` | `<class 'bool'>` | No | True | |
 | `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
+| `dpo_output_dir` | `<class 'str'>` | No | ./checkpoints/dpo | |
 
 ## Sample Configuration
```
Provider config class:

```diff
@@ -71,6 +71,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
     dpo_beta: float = 0.1
     use_reference_model: bool = True
     dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+    dpo_output_dir: str = "./checkpoints/dpo"
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
```
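For context, a self-contained sketch of the DPO-related fields after this change; it mirrors only the fields visible in this hunk, while the real class carries additional fields not shown here:

```python
# Sketch only: the fields visible in this diff, not the full provider config.
from typing import Literal

from pydantic import BaseModel


class HuggingFacePostTrainingConfig(BaseModel):
    dpo_beta: float = 0.1
    use_reference_model: bool = True
    dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
    dpo_output_dir: str = "./checkpoints/dpo"  # new: configurable DPO checkpoint root


# Deployments can now override the checkpoint root instead of relying on the
# previously hard-coded ./checkpoints/dpo path.
cfg = HuggingFacePostTrainingConfig(dpo_output_dir="/data/checkpoints/dpo")
print(cfg.dpo_output_dir)  # /data/checkpoints/dpo
```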
Post-training provider implementation:

```diff
@@ -132,12 +132,9 @@ class HuggingFacePostTrainingImpl:
             datasets_api=self.datasets_api,
         )
 
-        # Use default checkpoint directory
-        output_dir = f"./checkpoints/dpo/{job_uuid}"
-
         resources_allocated, checkpoints = await recipe.train(
             model=finetuned_model,
-            output_dir=output_dir,
+            output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
             job_uuid=job_uuid,
             dpo_config=algorithm_config,
             config=training_config,
```
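The net effect is that the per-job output directory is derived from configuration rather than a hard-coded literal. A trivial illustration of the new path derivation (the values are made up):

```python
# Illustrative only: how the recipe's output directory is now composed.
dpo_output_dir = "/data/checkpoints/dpo"  # from HuggingFacePostTrainingConfig
job_uuid = "job-1234"                     # hypothetical job id

output_dir = f"{dpo_output_dir}/{job_uuid}"
print(output_dir)  # /data/checkpoints/dpo/job-1234
```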
Supervised fine-tuning recipe (HFFinetuningSingleDevice):

```diff
@@ -8,20 +8,9 @@ import gc
 import json
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
 
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
-
 import torch
 from datasets import Dataset
 from peft import LoraConfig
```
```diff
@@ -39,6 +28,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
```
```diff
@@ -47,8 +37,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
```
```diff
@@ -239,7 +229,7 @@ class HFFinetuningSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
         logger.info(f"Loaded {len(rows)} rows from dataset")
```
```diff
@@ -383,6 +373,9 @@ class HFFinetuningSingleDevice:
     ) -> None:
         """Run the training process with signal handling."""
 
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
 
```
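`setup_environment()` is the shared helper that replaces the module-level environment setup deleted above. Its presumed body, inferred from the removed lines and from the `setup_environment` context visible in the utils hunk further down (the actual helper may differ in detail):

```python
import os


def setup_environment():
    """Set up environment variables for training (presumed body)."""
    # Set tokenizer parallelism environment variable
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Force PyTorch to use OpenBLAS instead of MKL
    os.environ["MKL_THREADING_LAYER"] = "GNU"
    os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
    os.environ["MKL_NUM_THREADS"] = "1"
```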
```diff
@@ -489,7 +482,8 @@ class HFFinetuningSingleDevice:
         logger.info("Starting training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
```
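Background on the inlined logic: fork-based child processes inherit CUDA/MPS runtime state and can hang or crash, so accelerator devices need the `spawn` start method. A minimal, self-contained sketch of the pattern (the worker function and message are illustrative, not from the diff):

```python
import multiprocessing


def _worker(msg: str) -> None:
    # Runs in a freshly spawned interpreter, with no inherited device state.
    print(msg)


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    p = multiprocessing.Process(target=_worker, args=("training in child process",))
    p.start()
    p.join()
```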
DPO alignment recipe (HFDPOAlignmentSingleDevice):

```diff
@@ -7,20 +7,9 @@
 import gc
 import logging
 import multiprocessing
-import os
 from pathlib import Path
 from typing import Any
 
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
-
-# Set tokenizer parallelism environment variable
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-# Force PyTorch to use OpenBLAS instead of MKL
-os.environ["MKL_THREADING_LAYER"] = "GNU"
-os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
-os.environ["MKL_NUM_THREADS"] = "1"
-
 import torch
 from datasets import Dataset
 from transformers import (
```
```diff
@@ -35,6 +24,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
```
```diff
@@ -43,8 +33,8 @@ from ..utils import (
     get_memory_stats,
     get_save_strategy,
     load_model,
-    setup_data,
-    setup_multiprocessing_for_device,
+    load_rows_from_dataset,
+    setup_environment,
     setup_signal_handlers,
     setup_torch_device,
     split_dataset,
```
```diff
@@ -64,49 +54,48 @@ class HFDPOAlignmentSingleDevice:
         self.datasets_api = datasets_api
         self.job_uuid = job_uuid
 
-    def validate_dataset_format(self, rows: list[dict]) -> bool:
+    def validate_dataset_format(self, rows: list[dict]) -> None:
         """Validate that the dataset has the required fields for DPO training."""
         required_fields = ["prompt", "chosen", "rejected"]
 
         if not rows:
             logger.warning("Dataset is empty")
-            return False
+            raise ValueError("Dataset is empty")
 
         for i, row in enumerate(rows):
             if not isinstance(row, dict):
                 logger.warning(f"Row {i} is not a dictionary")
-                return False
+                raise ValueError(f"Row {i} is not a dictionary")
 
             for field in required_fields:
                 if field not in row:
                     logger.warning(f"Row {i} missing required DPO field: {field}")
-                    return False
+                    raise ValueError(f"Row {i} missing required DPO field: {field}")
 
                 # Handle both string and list formats
                 if field == "prompt":
                     # Prompt should be a string
                     if not isinstance(row[field], str):
                         logger.warning(f"Row {i} field '{field}' is not a string")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is not a string")
                     if not row[field].strip():
                         logger.warning(f"Row {i} field '{field}' is empty")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is empty")
                 else:
                     # chosen/rejected can be either strings or lists of messages
                     if isinstance(row[field], str):
                         if not row[field].strip():
                             logger.warning(f"Row {i} field '{field}' is empty")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty")
                     elif isinstance(row[field], list):
                         if not row[field]:
                             logger.warning(f"Row {i} field '{field}' is empty list")
-                            return False
+                            raise ValueError(f"Row {i} field '{field}' is empty list")
                     else:
                         logger.warning(f"Row {i} field '{field}' is neither string nor list")
-                        return False
+                        raise ValueError(f"Row {i} field '{field}' is neither string nor list")
 
         logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
-        return True
 
     def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
         """Process a row in DPO format, handling both string and conversation list formats."""
```
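Since `validate_dataset_format` now raises instead of returning a bool, callers get the specific failing row and field for free. A small illustrative call, where the rows and the `recipe` instance are made up:

```python
# Illustrative only: exercising the new raise-based validation.
rows = [{"prompt": "What is DPO?", "chosen": "A preference-tuning method.", "rejected": ""}]

try:
    recipe.validate_dataset_format(rows)  # recipe: an HFDPOAlignmentSingleDevice
except ValueError as e:
    print(f"Invalid DPO dataset: {e}")  # Invalid DPO dataset: Row 0 field 'rejected' is empty
```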
```diff
@@ -220,9 +209,8 @@ class HFDPOAlignmentSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
-            raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
+        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
+        self.validate_dataset_format(rows)
         logger.info(f"Loaded {len(rows)} rows from dataset")
 
         # Initialize tokenizer
```
```diff
@@ -348,6 +336,9 @@ class HFDPOAlignmentSingleDevice:
     ) -> None:
         """Run the DPO training process with signal handling."""
 
+        # Setup environment variables
+        setup_environment()
+
         # Setup signal handlers
         setup_signal_handlers()
 
```
```diff
@@ -457,7 +448,8 @@ class HFDPOAlignmentSingleDevice:
         logger.info("Starting DPO training in separate process")
         try:
             # Setup multiprocessing for device
-            setup_multiprocessing_for_device(device)
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
 
             process = multiprocessing.Process(
                 target=self._run_training_sync,
```
Shared utils:

```diff
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import logging
-import multiprocessing
 import os
 import signal
 import sys
```
```diff
@@ -34,7 +33,7 @@ def setup_environment():
     os.environ["MKL_NUM_THREADS"] = "1"
 
 
-def get_gb(to_convert: int) -> str:
+def bytes_to_gb(to_convert: int) -> str:
     """Converts memory stats to GB and formats to 2 decimal places.
 
     Args:
         to_convert: Memory value in bytes
```
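The rename makes the unit conversion explicit at call sites. The body is not shown in this hunk; a presumed implementation matching the docstring (conversion to GB, two decimal places) would be:

```python
def bytes_to_gb(to_convert: int) -> str:
    """Converts memory stats to GB and formats to 2 decimal places."""
    # Presumed body: divide by 1024**3 and format; the real helper may differ.
    return f"{to_convert / (1024**3):.2f}"


print(bytes_to_gb(3_221_225_472))  # 3.00
```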
```diff
@@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
     """Get memory statistics for the given device."""
     stats = {
         "system_memory": {
-            "total": get_gb(psutil.virtual_memory().total),
-            "available": get_gb(psutil.virtual_memory().available),
-            "used": get_gb(psutil.virtual_memory().used),
+            "total": bytes_to_gb(psutil.virtual_memory().total),
+            "available": bytes_to_gb(psutil.virtual_memory().available),
+            "used": bytes_to_gb(psutil.virtual_memory().used),
             "percent": psutil.virtual_memory().percent,
         }
     }
 
     if device.type == "cuda":
         stats["device_memory"] = {
-            "allocated": get_gb(torch.cuda.memory_allocated(device)),
-            "reserved": get_gb(torch.cuda.memory_reserved(device)),
-            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
+            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
         }
     elif device.type == "mps":
         # MPS doesn't provide direct memory stats, but we can track system memory
         stats["device_memory"] = {
             "note": "MPS memory stats not directly available",
-            "system_memory_used": get_gb(psutil.virtual_memory().used),
+            "system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
         }
     elif device.type == "cpu":
         # For CPU, we track process memory usage
         process = psutil.Process()
         stats["device_memory"] = {
-            "process_rss": get_gb(process.memory_info().rss),
-            "process_vms": get_gb(process.memory_info().vms),
+            "process_rss": bytes_to_gb(process.memory_info().rss),
+            "process_vms": bytes_to_gb(process.memory_info().vms),
             "process_percent": process.memory_percent(),
         }
```
```diff
@@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
     return device
 
 
-async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
         all_rows = await datasetio_api.iterrows(
```
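The hunk cuts off inside the function body. For orientation only, a hedged sketch of what the loader plausibly does; the `iterrows` keyword arguments, the `limit=-1` convention, and the `.data` attribute are assumptions not confirmed by this diff:

```python
# Hypothetical body; only the signature and the iterrows call are visible above.
async def load_rows_from_dataset(datasetio_api, dataset_id: str) -> list[dict]:
    """Load dataset from llama stack dataset provider (sketch)."""
    try:
        all_rows = await datasetio_api.iterrows(
            dataset_id=dataset_id,
            limit=-1,  # assumed: fetch all rows
        )
        return all_rows.data  # assumed response attribute
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset {dataset_id}: {e}") from e
```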
```diff
@@ -268,12 +267,3 @@ def create_checkpoints(
         checkpoints.append(checkpoint)
 
     return checkpoints
-
-
-def setup_multiprocessing_for_device(device: torch.device):
-    """Setup multiprocessing start method based on device type.
-    Args:
-        device: The device being used for training
-    """
-    if device.type in ["cuda", "mps"]:
-        multiprocessing.set_start_method("spawn", force=True)
```