feat: DistributedJobScheduler

Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce DistributedJobScheduler, which kicks off a `finetune_handler.py` script using torchrun. The handler parses the training args via argparse
and calls the right recipe, as `post_training.py` used to do. torchrun takes care of environment variables such as WORLD_SIZE, LOCAL_RANK, etc.
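
A rough sketch of the handler's shape (the flag names and the dispatch step are illustrative assumptions, not the actual interface):

```python
# finetune_handler.py -- illustrative sketch only; flag names and the
# dispatch step are assumptions, not the actual handler interface.
import argparse
import os


def main() -> None:
    parser = argparse.ArgumentParser(description="Entry point launched by torchrun")
    parser.add_argument("--job-uuid", required=True)
    parser.add_argument("--model", required=True)
    parser.add_argument("--config", required=True, help="serialized TrainingConfig")
    args = parser.parse_args()

    # torchrun sets these for every process it launches; fall back to
    # single-process defaults so the script also runs standalone.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    # Here the real handler would build the recipe (single- or
    # multi-device) and run it, as post_training.py used to do.
    print(f"rank {local_rank}/{world_size}: job {args.job_uuid}, model {args.model}")


if __name__ == "__main__":
    main()
```

The scheduler would then launch it with something like `torchrun --nproc_per_node=8 finetune_handler.py --job-uuid ... --model ...`; torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT to each process it spawns.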

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern 2025-06-12 13:59:06 -04:00
parent 6494658a10
commit ce48d47543
5 changed files with 534 additions and 143 deletions


@@ -75,8 +75,6 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
@@ -191,8 +189,7 @@ class HFFinetuningMultiDevice:
     def __init__(
         self,
         job_uuid: str,
-        datasetio_api: DatasetIO,
-        datasets_api: Datasets,
+        data: list[dict[str, Any]],
         enable_nccl_debug: bool = False,
         nccl_debug_subsys: str = "NONE",
     ):
@@ -203,8 +200,7 @@ class HFFinetuningMultiDevice:
             datasetio_api: API for dataset I/O operations
             datasets_api: API for dataset management
         """
-        self.datasetio_api = datasetio_api
-        self.datasets_api = datasets_api
+        self.data = data
         self.job_uuid = job_uuid
         self.enable_nccl_debug = enable_nccl_debug
         self.nccl_debug_subsys = nccl_debug_subsys
@@ -408,29 +404,6 @@ class HFFinetuningMultiDevice:
             num_proc=1,  # Single process to avoid issues
         )
 
-    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
-        """Load dataset from llama stack dataset provider.
-
-        Args:
-            dataset_id: ID of the dataset to load
-
-        Returns:
-            list: List of dataset rows
-
-        Raises:
-            RuntimeError: If dataset loading fails
-        """
-        try:
-            all_rows = await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
-            if not isinstance(all_rows.data, list):
-                raise RuntimeError("Expected dataset data to be a list")
-            return all_rows.data
-        except Exception as e:
-            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
-
     def _run_training_sync(
         self,
         local_rank: int,  # First parameter must be local_rank for spawn
@@ -627,10 +600,9 @@ class HFFinetuningMultiDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await self._setup_data(config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
+        if not self.validate_dataset_format(self.data):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
-        logger.info(f"Loaded {len(rows)} rows from dataset")
+        logger.info(f"Loaded {len(self.data)} rows from dataset")
 
         # Initialize tokenizer
         logger.info(f"Initializing tokenizer for model: {model}")
@@ -662,7 +634,7 @@ class HFFinetuningMultiDevice:
         # Create and preprocess dataset
         logger.info("Creating and preprocessing dataset")
         try:
-            ds = self._create_dataset(rows, config, provider_config)
+            ds = self._create_dataset(self.data, config, provider_config)
             ds = self._preprocess_dataset(ds, tokenizer, provider_config)
             logger.info(f"Dataset created with {len(ds)} examples")
         except Exception as e:
@@ -1021,37 +993,16 @@ class HFFinetuningMultiDevice:
         config: TrainingConfig,
         provider_config: HuggingFacePostTrainingConfig,
     ) -> tuple[dict[str, Any], list[Checkpoint] | None]:
-        """Train a model using HuggingFace's SFTTrainer with distributed training.
-
-        The distributed training setup works as follows:
-        1. Parse the device list to determine number of GPUs
-        2. Use torch.multiprocessing.spawn to launch one process per GPU
-        3. Each process runs _run_training_sync with a unique rank
-        4. The processes coordinate through NCCL backend
-        5. FSDP handles model sharding across GPUs
-        6. Only rank 0 handles saving checkpoints and logging
-
-        Args:
-            model: The model identifier to load
-            output_dir: Optional directory to save checkpoints
-            job_uuid: Unique identifier for this training job
-            lora_config: LoRA configuration for parameter-efficient fine-tuning
-            config: General training configuration
-            provider_config: Provider-specific configuration
-
-        Returns:
-            tuple: (memory_stats, checkpoints)
-        """
+        """Train a model using HuggingFace's SFTTrainer with distributed training."""
         if provider_config.distributed_backend != "fsdp":
             raise RuntimeError("Must enable FSDP as distributed backend to use this recipe")
 
         # Configure NCCL logging based on debug settings
         configure_nccl_logging(self.enable_nccl_debug, self.nccl_debug_subsys)
 
-        # Parse device list to determine number of GPUs
-        devices = [d.strip() for d in provider_config.device.split(",")]
-        world_size = len(devices)
-        logger.info(f"Using {world_size} devices: {devices}")
+        # Get local rank and world size from environment variables
+        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        world_size = int(os.environ.get("WORLD_SIZE", "1"))
 
         output_dir_path = None
         if output_dir:
@@ -1081,32 +1032,22 @@
             raise ValueError("DataConfig is required for training")
 
         try:
-            # Launch distributed training processes
-            # torch.multiprocessing.spawn will:
-            # 1. Create world_size number of processes
-            # 2. Call _run_training_sync for each process
-            # 3. Pass unique local_rank to each process
-            # 4. Handle process coordination and cleanup
-            logger.info("Starting distributed training processes")
-            torch.multiprocessing.spawn(
-                self._run_training_sync,
-                args=(
-                    world_size,
-                    model,
-                    provider_config.model_dump(),
-                    peft_config,
-                    config.model_dump(),
-                    output_dir_path,
-                ),
-                nprocs=world_size,
-                join=True,  # Wait for all processes to complete
+            # Run training for this process
+            await self._run_training(
+                model=model,
+                provider_config=provider_config.model_dump(),
+                peft_config=peft_config,
+                config=config.model_dump(),
+                output_dir_path=output_dir_path,
+                local_rank=local_rank,
+                world_size=world_size,
             )
 
             memory_stats["after_training"] = get_memory_stats(torch.device("cuda:0"))
 
-            # Create checkpoint on rank 0
+            # Only create checkpoint on rank 0
            checkpoints = None
-            if output_dir_path:
+            if output_dir_path and local_rank == 0:
                 checkpoint = Checkpoint(
                     identifier=f"{model}-sft-{config.n_epochs}",
                     created_at=datetime.now(timezone.utc),
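
The core of the change is visible in the last two hunks: the recipe no longer spawns one process per GPU itself; each torchrun-launched process discovers its rank from the environment and runs the training path directly. A minimal sketch of that pattern, using the standard torch.distributed API:

```python
import os

import torch
import torch.distributed as dist

# Under torchrun every process starts here already knowing its place in
# the job, so no torch.multiprocessing.spawn is needed in the recipe.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

if world_size > 1 and not dist.is_initialized():
    # torchrun also exports RANK, MASTER_ADDR, and MASTER_PORT, so the
    # default env:// rendezvous works without extra arguments.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
```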


@@ -37,8 +37,6 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
@@ -136,11 +134,9 @@ class HFFinetuningSingleDevice:
     def __init__(
         self,
         job_uuid: str,
-        datasetio_api: DatasetIO,
-        datasets_api: Datasets,
+        data: list[dict[str, Any]],
     ):
-        self.datasetio_api = datasetio_api
-        self.datasets_api = datasets_api
+        self.data = data
         self.job_uuid = job_uuid
 
     def validate_dataset_format(self, rows: list[dict]) -> bool:
@@ -262,19 +258,6 @@ class HFFinetuningSingleDevice:
             remove_columns=ds.column_names,
         )
 
-    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
-        """Load dataset from llama stack dataset provider"""
-        try:
-            all_rows = await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
-            if not isinstance(all_rows.data, list):
-                raise RuntimeError("Expected dataset data to be a list")
-            return all_rows.data
-        except Exception as e:
-            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
-
     def _run_training_sync(
         self,
         model: str,
@@ -327,10 +310,9 @@ class HFFinetuningSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await self._setup_data(config.data_config.dataset_id)
-        if not self.validate_dataset_format(rows):
+        if not self.validate_dataset_format(self.data):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
-        logger.info(f"Loaded {len(rows)} rows from dataset")
+        logger.info(f"Loaded {len(self.data)} rows from dataset")
 
         # Initialize tokenizer
         logger.info(f"Initializing tokenizer for model: {model}")
@@ -362,7 +344,7 @@ class HFFinetuningSingleDevice:
         # Create and preprocess dataset
         logger.info("Creating and preprocessing dataset")
         try:
-            ds = self._create_dataset(rows, config, provider_config)
+            ds = self._create_dataset(self.data, config, provider_config)
             ds = self._preprocess_dataset(ds, tokenizer, provider_config)
             logger.info(f"Dataset created with {len(ds)} examples")
         except Exception as e:
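
With the `_setup_data` helpers gone, whatever constructs a recipe must fetch the rows first and hand them in. A caller-side sketch (the `iterrows` call is the code these hunks removed; the surrounding wiring is hypothetical):

```python
from typing import Any


async def load_rows(datasetio_api, dataset_id: str) -> list[dict[str, Any]]:
    # Same logic as the removed _setup_data helpers: pull every row up
    # front so the training processes never need dataset API access.
    all_rows = await datasetio_api.iterrows(dataset_id=dataset_id, limit=-1)
    if not isinstance(all_rows.data, list):
        raise RuntimeError("Expected dataset data to be a list")
    return all_rows.data


# e.g. recipe = HFFinetuningSingleDevice(job_uuid=job_uuid,
#                                        data=await load_rows(datasetio_api, dataset_id))
```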