feat: DistributedJobScheduler

Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which kicks off a `finetune_handler.py` script via torchrun. The handler parses the training args with argparse
and invokes the appropriate recipe, as `post_training.py` used to do. Torchrun takes care of environment variables such as `WORLD_SIZE`, `LOCAL_RANK`, etc.
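
For illustration, a minimal sketch of how a distributed scheduler backend could build the torchrun invocation that hands these arguments to `finetune_handler.py`. The helper name, handler path, and parameter plumbing below are assumptions for this sketch, not the actual DistributedJobScheduler code (which lives in the scheduler change not shown in this view):

# Sketch only: an assumed helper, not the real DistributedJobScheduler API.
import subprocess
import sys

def launch_distributed_finetune(run_params: dict[str, str], world_size: int,
                                handler_path: str = "finetune_handler.py") -> None:
    """Launch finetune_handler.py under torchrun with one worker per GPU."""
    cmd = [
        sys.executable, "-m", "torch.distributed.run",  # same entry point as the `torchrun` CLI
        f"--nproc_per_node={world_size}",
        handler_path,
        "--job_uuid", run_params["job_uuid"],
        "--model", run_params["model"],
        "--training_config", run_params["training_config"],
        "--provider_config", run_params["provider_config"],
        "--data", run_params["data"],
    ]
    if "checkpoint_dir" in run_params:
        cmd += ["--checkpoint_dir", run_params["checkpoint_dir"]]
    if "algorithm_config" in run_params:
        cmd += ["--algorithm_config", run_params["algorithm_config"]]
    # torchrun exports WORLD_SIZE, RANK and LOCAL_RANK to each worker it spawns,
    # so the handler only needs to read them from the environment.
    subprocess.run(cmd, check=True)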

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern 2025-06-12 13:59:06 -04:00
parent 6494658a10
commit ce48d47543
5 changed files with 534 additions and 143 deletions

finetune_handler.py (new file)

@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
from typing import Any
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import (
HFFinetuningMultiDevice,
)
from llama_stack.providers.utils.scheduler import JobStatus
async def train(
job_uuid,
model,
checkpoint_dir,
training_config,
provider_config,
algorithm_config,
data,
enable_nccl_debug=False,
nccl_debug_subsys="NONE",
):
"""Handler function for HuggingFace training that can be called by torchrun.
This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class.
It follows the same flow, but is designed to be called directly from a script.
Args:
job_uuid: Unique ID for this job
model: Model to train
checkpoint_dir: Directory to save checkpoints to
training_config: Training configuration
provider_config: Provider configuration
algorithm_config: Algorithm configuration
data: the dataset rows to be processed
enable_nccl_debug: Whether to enable NCCL debugging
nccl_debug_subsys: NCCL subsystem to debug
"""
# Get rank information when running distributed
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
parsed_data: list[dict[str, Any]] = json.loads(data)
# Set up callback functions with rank information
def on_log_message_cb(msg):
print(f"[RANK {local_rank}] {msg}", flush=True)
def on_status_change_cb(status):
print(f"[RANK {local_rank}] Status: {status}", flush=True)
def on_artifact_collected_cb(artifact):
print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True)
on_log_message_cb("Starting HF finetuning")
recipe_obj = HFFinetuningMultiDevice(
job_uuid=job_uuid, enable_nccl_debug=enable_nccl_debug, nccl_debug_subsys=nccl_debug_subsys, data=parsed_data
)
resources_allocated, checkpoints = await recipe_obj.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=provider_config,
)
def resources_stats_to_artifact(resources_stats):
return {
"type": "resources_stats",
"name": "resources_stats",
"metadata": resources_stats,
}
def checkpoint_to_artifact(checkpoint):
return {
"type": "checkpoint",
"name": checkpoint.identifier,
"uri": checkpoint.path,
"metadata": dict(checkpoint),
}
on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(JobStatus.completed)
on_log_message_cb("HF finetuning completed")
async def main():
parser = argparse.ArgumentParser(description="Run HuggingFace training with torchrun.")
parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID")
parser.add_argument("--model", type=str, required=True, help="Model to use")
parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints")
parser.add_argument("--training_config", type=str, required=True, help="Training config JSON")
parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON")
parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON")
parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging")
parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug")
parser.add_argument("--data", type=str, required=True)
args = parser.parse_args()
# Parse JSON configs
try:
training_config = TrainingConfig.model_validate_json(args.training_config)
except Exception as e:
print(f"Error parsing training_config: {e}")
print(f"Received: {args.training_config}")
raise
try:
provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config)
except Exception as e:
print(f"Error parsing provider_config: {e}")
print(f"Received: {args.provider_config}")
raise
algorithm_config = None
if args.algorithm_config:
try:
algorithm_config = json.loads(args.algorithm_config)
except json.JSONDecodeError as e:
print(f"Error parsing algorithm_config: {e}")
print(f"Received: {args.algorithm_config}")
raise
# Print arguments for debugging
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if local_rank == 0: # Only the main process prints
print("Starting training with arguments:")
print(f" job_uuid: {args.job_uuid}")
print(f" model: {args.model}")
print(f" checkpoint_dir: {args.checkpoint_dir}")
print(f" enable_nccl_debug: {args.enable_nccl_debug}")
print(f" nccl_debug_subsys: {args.nccl_debug_subsys}")
await train(
job_uuid=args.job_uuid,
model=args.model,
checkpoint_dir=args.checkpoint_dir,
training_config=training_config,
provider_config=provider_config,
algorithm_config=algorithm_config,
enable_nccl_debug=args.enable_nccl_debug,
nccl_debug_subsys=args.nccl_debug_subsys,
data=args.data,
)
if __name__ == "__main__":
asyncio.run(main())
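
For reference, a small sketch of the data hand-off this handler expects: the provider loads the dataset rows, serializes them with json.dumps, and passes the string via --data, and train() restores them with json.loads. The row values below are invented; only the field names come from the validate_dataset_format error message in the recipes.

import json

# Hypothetical rows; field names match the recipes' required-fields check.
rows = [
    {
        "input_query": "What is 2 + 2?",
        "expected_answer": "4",
        "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
    }
]

json_data = json.dumps(rows)    # done in HuggingFacePostTrainingImpl.supervised_fine_tune
parsed = json.loads(json_data)  # done at the start of train() above
assert parsed == rows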

post_training.py

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from enum import Enum
from typing import Any
@@ -81,43 +82,130 @@ class HuggingFacePostTrainingImpl:
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
from collections.abc import Callable, Coroutine
from typing import Any
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
if self.config.recipe == "multi":
recipe = HFFinetuningMultiDevice(
# Type for the handler: an async function taking the three scheduler callbacks (log, status, artifact) and returning None
handler: (
Callable[
[Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]],
Coroutine[Any, Any, None],
]
| None
) = None
# Determine world size for distributed training
world_size = getattr(self.config, "world_size", 1)
run_params: dict[str, Any] = {}  # parameters handed to the torchrun handler; populated on the distributed path
# Choose the backend and recipe based on world size
if world_size > 1:
recipe = "multi"
# Create parameters for the handler script
run_params = {
"job_uuid": job_uuid,
"model": model,
"world_size": world_size,
"recipe": recipe,
}
# Add optional parameters
if checkpoint_dir is not None:
run_params["checkpoint_dir"] = checkpoint_dir
if training_config is not None:
run_params["training_config"] = training_config.model_dump_json()
if algorithm_config is not None:
run_params["algorithm_config"] = algorithm_config.model_dump_json()
# Add provider-specific configuration
run_params["provider_config"] = self.config.model_dump_json()
# Add NCCL debug settings if present
if hasattr(self.config, "enable_nccl_debug"):
run_params["enable_nccl_debug"] = self.config.enable_nccl_debug
if hasattr(self.config, "nccl_debug_subsys"):
run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys
# Initialize the scheduler with the distributed backend
self._scheduler = Scheduler(backend="distributed")
else:
self._scheduler = Scheduler(backend="naive")
# TODO: this can probably be cleaner
# Single-device path: define an in-process handler for the naive scheduler backend
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
enable_nccl_debug=self.config.enable_nccl_debug,
nccl_debug_subsys=self.config.nccl_debug_subsys,
)
if self.config.recipe == "multi":
recipe = HFFinetuningMultiDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False),
nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"),
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
assert training_config.data_config is not None
data = await self._setup_data(dataset_id=training_config.data_config.dataset_id)
json_data = json.dumps(data)
run_params["data"] = json_data
# Schedule the job on whichever scheduler backend was selected above
job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params)
return PostTrainingJob(job_uuid=job_id)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider.
Args:
dataset_id: ID of the dataset to load
Returns:
list: List of dataset rows
Raises:
RuntimeError: If dataset loading fails
"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
async def preference_optimize(
self,

recipes/finetune_multi_device.py

@@ -75,8 +75,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
@@ -191,8 +189,7 @@ class HFFinetuningMultiDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
data: list[dict[str, Any]],
enable_nccl_debug: bool = False,
nccl_debug_subsys: str = "NONE",
):
@@ -203,8 +200,7 @@ class HFFinetuningMultiDevice:
data: dataset rows to train on, already loaded and passed in by the provider
"""
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.data = data
self.job_uuid = job_uuid
self.enable_nccl_debug = enable_nccl_debug
self.nccl_debug_subsys = nccl_debug_subsys
@@ -408,29 +404,6 @@ class HFFinetuningMultiDevice:
num_proc=1, # Single process to avoid issues
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider.
Args:
dataset_id: ID of the dataset to load
Returns:
list: List of dataset rows
Raises:
RuntimeError: If dataset loading fails
"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
local_rank: int, # First parameter must be local_rank for spawn
@@ -627,10 +600,9 @@ class HFFinetuningMultiDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
if not self.validate_dataset_format(self.data):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
logger.info(f"Loaded {len(self.data)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
@@ -662,7 +634,7 @@ class HFFinetuningMultiDevice:
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._create_dataset(self.data, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
@@ -1021,37 +993,16 @@ class HFFinetuningMultiDevice:
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's SFTTrainer with distributed training.
The distributed training setup works as follows:
1. Parse the device list to determine number of GPUs
2. Use torch.multiprocessing.spawn to launch one process per GPU
3. Each process runs _run_training_sync with a unique rank
4. The processes coordinate through NCCL backend
5. FSDP handles model sharding across GPUs
6. Only rank 0 handles saving checkpoints and logging
Args:
model: The model identifier to load
output_dir: Optional directory to save checkpoints
job_uuid: Unique identifier for this training job
lora_config: LoRA configuration for parameter-efficient fine-tuning
config: General training configuration
provider_config: Provider-specific configuration
Returns:
tuple: (memory_stats, checkpoints)
"""
"""Train a model using HuggingFace's SFTTrainer with distributed training."""
if provider_config.distributed_backend != "fsdp":
raise RuntimeError("Must enable FSDP as distributed backend to use this recipe")
# Configure NCCL logging based on debug settings
configure_nccl_logging(self.enable_nccl_debug, self.nccl_debug_subsys)
# Parse device list to determine number of GPUs
devices = [d.strip() for d in provider_config.device.split(",")]
world_size = len(devices)
logger.info(f"Using {world_size} devices: {devices}")
# Get local rank and world size from environment variables
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
output_dir_path = None
if output_dir:
@@ -1081,32 +1032,22 @@
raise ValueError("DataConfig is required for training")
try:
# Launch distributed training processes
# torch.multiprocessing.spawn will:
# 1. Create world_size number of processes
# 2. Call _run_training_sync for each process
# 3. Pass unique local_rank to each process
# 4. Handle process coordination and cleanup
logger.info("Starting distributed training processes")
torch.multiprocessing.spawn(
self._run_training_sync,
args=(
world_size,
model,
provider_config.model_dump(),
peft_config,
config.model_dump(),
output_dir_path,
),
nprocs=world_size,
join=True, # Wait for all processes to complete
# Run training for this process
await self._run_training(
model=model,
provider_config=provider_config.model_dump(),
peft_config=peft_config,
config=config.model_dump(),
output_dir_path=output_dir_path,
local_rank=local_rank,
world_size=world_size,
)
memory_stats["after_training"] = get_memory_stats(torch.device("cuda:0"))
# Create checkpoint on rank 0
# Only create checkpoint on rank 0
checkpoints = None
if output_dir_path:
if output_dir_path and local_rank == 0:
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(timezone.utc),

recipes/finetune_single_device.py

@@ -37,8 +37,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
@@ -136,11 +134,9 @@ class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
data: list[dict[str, Any]],
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.data = data
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
@@ -262,19 +258,6 @@ class HFFinetuningSingleDevice:
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
@@ -327,10 +310,9 @@ class HFFinetuningSingleDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
if not self.validate_dataset_format(self.data):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
logger.info(f"Loaded {len(self.data)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
@@ -362,7 +344,7 @@ class HFFinetuningSingleDevice:
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._create_dataset(self.data, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e: