commit 6f38d12853 by Charlie Doern, 2025-06-18 09:20:19 -07:00 (committed by GitHub)
8 changed files with 1598 additions and 63 deletions

View file

@@ -57,7 +57,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
# L2 regularization coefficient
# Helps prevent overfitting
- weight_decay: float = 0.01
weight_decay: float = 0.00
# Number of worker processes for data loading
# Higher values can improve data loading speed but increase memory usage
@@ -67,6 +67,17 @@ class HuggingFacePostTrainingConfig(BaseModel):
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
# Recipe type for training (single or multi device)
recipe: str = "single"
# NCCL debug configuration for distributed training
# Enable detailed NCCL logging for debugging distributed training issues
enable_nccl_debug: bool = False
# NCCL subsystems to debug (NONE, ALL, INIT, COLL, P2P, SHM, NET)
# Controls which NCCL components generate debug output
nccl_debug_subsys: str = "NONE"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
- return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu", "recipe": "single"}
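For reference, a minimal usage sketch of the new fields, assuming the import path shown in this commit and that the fields not visible in this hunk (checkpoint_format, device) accept these values or keep defaults. The NCCL environment-variable mapping at the end is also an assumption; the multi-device recipe that actually consumes enable_nccl_debug and nccl_debug_subsys is not part of the hunks shown here.

# Hypothetical sketch (not part of the commit): building the provider config
# with the new multi-device fields.
import os

from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

config = HuggingFacePostTrainingConfig(
    checkpoint_format="huggingface",
    device="cpu",
    recipe="multi",             # "single" (default) or "multi"
    enable_nccl_debug=True,     # verbose NCCL logging for distributed runs
    nccl_debug_subsys="INIT",   # NONE, ALL, INIT, COLL, P2P, SHM or NET
)

# Assumption: the multi-device recipe presumably translates these fields into
# NCCL's standard environment variables before initializing the process group.
if config.enable_nccl_debug:
    os.environ["NCCL_DEBUG"] = "INFO"
    os.environ["NCCL_DEBUG_SUBSYS"] = config.nccl_debug_subsys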

View file

@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import json
import os
from typing import Any
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.inline.post_training.huggingface.config import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import (
HFFinetuningMultiDevice,
)
from llama_stack.providers.utils.scheduler import JobStatus
async def train(
job_uuid,
model,
checkpoint_dir,
training_config,
provider_config,
algorithm_config,
data,
datasetio_api=None,  # unused here; the dataset rows arrive pre-fetched via `data`
datasets_api=None,  # unused here; accepted so main() can pass it through
enable_nccl_debug=False,
nccl_debug_subsys="NONE",
):
"""Handler function for HuggingFace training that can be called by torchrun.
This is extracted from the supervised_fine_tune method in the HuggingFacePostTrainingImpl class.
It follows the same flow, but is designed to be called directly from a script.
Args:
job_uuid: Unique ID for this job
model: Model to train
checkpoint_dir: Directory to save checkpoints to
training_config: Training configuration
provider_config: Provider configuration
algorithm_config: Algorithm configuration
data: The dataset rows to be processed (a JSON-encoded list of dicts)
datasetio_api: Unused in this standalone entrypoint; accepted for call compatibility
datasets_api: Unused in this standalone entrypoint; accepted for call compatibility
enable_nccl_debug: Whether to enable NCCL debugging
nccl_debug_subsys: NCCL subsystem to debug
"""
# Get rank information when running distributed
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
parsed_data: list[dict[str, Any]] = json.loads(data)
# Set up callback functions with rank information
def on_log_message_cb(msg):
print(f"[RANK {local_rank}] {msg}", flush=True)
def on_status_change_cb(status):
print(f"[RANK {local_rank}] Status: {status}", flush=True)
def on_artifact_collected_cb(artifact):
print(f"[RANK {local_rank}] Artifact: {artifact}", flush=True)
on_log_message_cb("Starting HF finetuning")
recipe_obj = HFFinetuningMultiDevice(
job_uuid=job_uuid, enable_nccl_debug=enable_nccl_debug, nccl_debug_subsys=nccl_debug_subsys, data=parsed_data
)
resources_allocated, checkpoints = await recipe_obj.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=provider_config,
)
def resources_stats_to_artifact(resources_stats):
return {
"type": "resources_stats",
"name": "resources_stats",
"metadata": resources_stats,
}
def checkpoint_to_artifact(checkpoint):
return {
"type": "checkpoint",
"name": checkpoint.identifier,
"uri": checkpoint.path,
"metadata": dict(checkpoint),
}
on_artifact_collected_cb(resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(JobStatus.completed)
on_log_message_cb("HF finetuning completed")
async def main():
parser = argparse.ArgumentParser(description="Run HuggingFace training with torchrun.")
parser.add_argument("--job_uuid", type=str, required=True, help="Job UUID")
parser.add_argument("--model", type=str, required=True, help="Model to use")
parser.add_argument("--checkpoint_dir", type=str, help="Directory to save checkpoints")
parser.add_argument("--training_config", type=str, required=True, help="Training config JSON")
parser.add_argument("--provider_config", type=str, required=True, help="Provider config JSON")
parser.add_argument("--algorithm_config", type=str, help="Algorithm config JSON")
parser.add_argument("--enable_nccl_debug", action="store_true", help="Enable NCCL debugging")
parser.add_argument("--nccl_debug_subsys", type=str, default="NONE", help="NCCL subsystem to debug")
parser.add_argument("--data", type=str, required=True)
args = parser.parse_args()
# Parse JSON configs
try:
training_config = TrainingConfig.model_validate_json(args.training_config)
except Exception as e:
print(f"Error parsing training_config: {e}")
print(f"Received: {args.training_config}")
raise
try:
provider_config = HuggingFacePostTrainingConfig.model_validate_json(args.provider_config)
except Exception as e:
print(f"Error parsing provider_config: {e}")
print(f"Received: {args.provider_config}")
raise
algorithm_config = None
if args.algorithm_config:
try:
algorithm_config = json.loads(args.algorithm_config)
except json.JSONDecodeError as e:
print(f"Error parsing algorithm_config: {e}")
print(f"Received: {args.algorithm_config}")
raise
# The DatasetIO/Datasets API handles are not available inside a standalone
# torchrun worker, so pass None; train() ignores them and relies on --data instead.
datasetio_api = None
datasets_api = None
# Print arguments for debugging
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if local_rank == 0: # Only the main process prints
print("Starting training with arguments:")
print(f" job_uuid: {args.job_uuid}")
print(f" model: {args.model}")
print(f" checkpoint_dir: {args.checkpoint_dir}")
print(f" enable_nccl_debug: {args.enable_nccl_debug}")
print(f" nccl_debug_subsys: {args.nccl_debug_subsys}")
await train(
job_uuid=args.job_uuid,
model=args.model,
checkpoint_dir=args.checkpoint_dir,
training_config=training_config,
provider_config=provider_config,
algorithm_config=algorithm_config,
datasetio_api=datasetio_api,
datasets_api=datasets_api,
enable_nccl_debug=args.enable_nccl_debug,
nccl_debug_subsys=args.nccl_debug_subsys,
data=args.data,
)
if __name__ == "__main__":
asyncio.run(main())
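For orientation, this is roughly the command line that the distributed scheduler backend (later in this commit) assembles for the script above. All values below are placeholders, and the script path is normally resolved from finetune_handler.__file__.

# Illustrative launch of finetune_handler.py; mirrors the argument list built
# by DistributedJobScheduler.schedule(). Every value here is a placeholder.
import json

rows = [{"chat_completion_input": "...", "expected_answer": "..."}]
cmd = [
    "torchrun",
    "--nproc_per_node=2",            # job.world_size
    "--master_addr=localhost",
    "--master_port=29500",
    "finetune_handler.py",           # resolved from finetune_handler.__file__
    "--job_uuid", "job-1234",
    "--model", "my-base-model",
    "--training_config", "{}",       # TrainingConfig serialized with model_dump_json()
    "--provider_config", "{}",       # HuggingFacePostTrainingConfig serialized with model_dump_json()
    "--data", json.dumps(rows),
]
print(" ".join(cmd))  # the scheduler hands this list to asyncio.create_subprocess_exec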

View file

@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from enum import Enum
from typing import Any
@@ -22,6 +23,7 @@ from llama_stack.apis.post_training import (
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_multi_device import HFFinetuningMultiDevice
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
@@ -80,6 +82,61 @@ class HuggingFacePostTrainingImpl:
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
from collections.abc import Callable, Coroutine
from typing import Any
# Type for the handler: an async function taking the log, status-change, and artifact callbacks and returning None
handler: (
Callable[
[Callable[[str], None], Callable[[SchedulerJobStatus], None], Callable[[JobArtifact], None]],
Coroutine[Any, Any, None],
]
| None
) = None
# Determine world size for distributed training
world_size = getattr(self.config, "world_size", 1)
# Choose the backend and recipe based on world size
if world_size > 1:
recipe = "multi"
# Create parameters for the handler script
run_params = {
"job_uuid": job_uuid,
"model": model,
"world_size": world_size,
"recipe": recipe,
}
# Add optional parameters
if checkpoint_dir is not None:
run_params["checkpoint_dir"] = checkpoint_dir
if training_config is not None:
run_params["training_config"] = training_config.model_dump_json()
if algorithm_config is not None:
run_params["algorithm_config"] = algorithm_config.model_dump_json()
# Add provider-specific configuration
run_params["provider_config"] = self.config.model_dump_json()
# Add NCCL debug settings if present
if hasattr(self.config, "enable_nccl_debug"):
run_params["enable_nccl_debug"] = self.config.enable_nccl_debug
if hasattr(self.config, "nccl_debug_subsys"):
run_params["nccl_debug_subsys"] = self.config.nccl_debug_subsys
# Initialize the scheduler with the distributed backend
self._scheduler = Scheduler(backend="distributed")
else:
self._scheduler = Scheduler(backend="naive")
# TODO: this can probably be cleaner
# Single-device training path
# Define a handler for single-device training
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
@@ -88,6 +145,14 @@ class HuggingFacePostTrainingImpl:
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
if self.config.recipe == "multi":
recipe = HFFinetuningMultiDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
enable_nccl_debug=getattr(self.config, "enable_nccl_debug", False),
nccl_debug_subsys=getattr(self.config, "nccl_debug_subsys", "NONE"),
)
resources_allocated, checkpoints = await recipe.train(
model=model,
@@ -107,8 +172,40 @@ class HuggingFacePostTrainingImpl:
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
- job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
- return PostTrainingJob(job_uuid=job_uuid)
assert training_config.data_config is not None
data = self._setup_data(dataset_id=training_config.data_config.dataset_id)
json_data = json.dumps(data)
run_params["data"] = json_data
# Schedule the job with the regular scheduler and the handler
job_id = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler, run_params)
return PostTrainingJob(job_uuid=job_id)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider.
Args:
dataset_id: ID of the dataset to load
Returns:
list: List of dataset rows
Raises:
RuntimeError: If dataset loading fails
"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
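For context, a hypothetical example of a single row returned here, using the field names that the single-device recipe's validation (further down in this commit) checks for; the values and the chat format are invented.

# Hypothetical dataset row; the key names come from the single-device recipe's
# validation error message, everything else is made up.
row = {
    "input_query": "What is the capital of France?",
    "expected_answer": "Paris",
    "chat_completion_input": '[{"role": "user", "content": "What is the capital of France?"}]',
}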
async def preference_optimize(
self,

View file

@@ -37,8 +37,6 @@ from transformers import (
)
from trl import SFTConfig, SFTTrainer
- from llama_stack.apis.datasetio import DatasetIO
- from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
@@ -136,11 +134,9 @@ class HFFinetuningSingleDevice:
def __init__(
self,
job_uuid: str,
- datasetio_api: DatasetIO,
- datasets_api: Datasets,
data: list[dict[str, Any]],
):
- self.datasetio_api = datasetio_api
- self.datasets_api = datasets_api
self.data = data
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> bool:
@@ -262,19 +258,6 @@ class HFFinetuningSingleDevice:
remove_columns=ds.column_names,
)
- async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
- """Load dataset from llama stack dataset provider"""
- try:
- all_rows = await self.datasetio_api.iterrows(
- dataset_id=dataset_id,
- limit=-1,
- )
- if not isinstance(all_rows.data, list):
- raise RuntimeError("Expected dataset data to be a list")
- return all_rows.data
- except Exception as e:
- raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
@@ -327,10 +310,9 @@ class HFFinetuningSingleDevice:
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
- rows = await self._setup_data(config.data_config.dataset_id)
- if not self.validate_dataset_format(rows):
if not self.validate_dataset_format(self.data):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
- logger.info(f"Loaded {len(rows)} rows from dataset")
logger.info(f"Loaded {len(self.data)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
@@ -362,7 +344,7 @@ class HFFinetuningSingleDevice:
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
try:
- ds = self._create_dataset(rows, config, provider_config)
ds = self._create_dataset(self.data, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:

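A brief sketch of the updated constructor contract, assuming the import path used elsewhere in this commit: the recipe now receives already-fetched rows instead of the DatasetIO/Datasets handles. Values are placeholders.

# Hypothetical construction of the updated single-device recipe: the dataset
# rows are injected directly rather than fetched through DatasetIO inside it.
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
    HFFinetuningSingleDevice,
)

recipe = HFFinetuningSingleDevice(
    job_uuid="job-1234",
    data=[{"input_query": "...", "expected_answer": "...", "chat_completion_input": "..."}],
)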
View file

@@ -7,10 +7,12 @@
import abc
import asyncio
import functools
import multiprocessing
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, TypeAlias
from pydantic import BaseModel
@@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
- def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
super().__init__()
self.id = job_id
self._type = job_type
@@ -62,9 +64,38 @@ class Job:
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
self._child_processes: list[multiprocessing.Process] = []
self._world_size: int = 1 # Number of processes for distributed training
self.run_args: dict[str, Any] = {} # Dictionary to store run arguments
@property
- def handler(self) -> JobHandler:
def world_size(self) -> int:
return self._world_size
@world_size.setter
def world_size(self, size: int) -> None:
self._world_size = size
def add_child_process(self, process: multiprocessing.Process) -> None:
self._child_processes.append(process)
def cancel(self) -> None:
"""Cancel the job and all its child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
self.status = JobStatus.failed
def cleanup(self) -> None:
"""Clean up any remaining child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
@property
def handler(self) -> JobHandler | None:
return self._handler
@property
@@ -111,10 +142,6 @@ class Job:
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
- # TODO: implement
- def cancel(self) -> None:
- raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
@@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
- # There may be performance implications of using threads due to Python
- # GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
@@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
- # TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
@@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
- # TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
@@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
@@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
pass
class DistributedJobScheduler(_SchedulerBackend):
"""A scheduler backend that supports distributed training jobs.
This scheduler uses torchrun to handle distributed training process spawning and coordination.
torchrun automatically handles:
- Process spawning
- Environment variable setup
- Process group initialization
- Error handling and process cleanup
"""
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
# Clean up any remaining processes
for process in self._active_jobs.values():
if process.returncode is None: # Process is still running
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
process.kill()
await process.wait()
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
# If this is a distributed training job, use torchrun
if job.world_size > 1:
# Find the path to finetune_handler.py
from llama_stack.providers.inline.post_training.huggingface import finetune_handler
handler_path = Path(finetune_handler.__file__)
# Prepare arguments for the handler script
args = [
"torchrun",
f"--nproc_per_node={job.world_size}",
"--master_addr=localhost",
"--master_port=29500",
str(handler_path),
]
# Add arguments from the job.run_args dictionary as proper command-line flags
for arg_name, arg_value in job.run_args.items():
# Skip world_size as we've already handled it
if arg_name == "world_size":
continue
if arg_value is not None:
# Handle boolean flags
if isinstance(arg_value, bool):
if arg_value:
args.append(f"--{arg_name}")
else:
# For non-boolean values, we add the argument as a separate flag and value
args.append(f"--{arg_name}")
args.append(str(arg_value))
# Launch torchrun using asyncio
on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
on_log_message_cb(f"Command: {' '.join(args)}")
# Make sure we capture stdout and stderr
process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
# Store process for this job
self._active_jobs[job.id] = process
# Start monitoring in a separate task so we don't block
asyncio.create_task(
self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
)
else:
# For single-device training, call the handler directly if provided
if job.handler:
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
else:
on_log_message_cb("No handler function provided for single-device training")
job.status = JobStatus.failed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
async def _monitor_process(
self,
job: Job,
process: asyncio.subprocess.Process,
script_path: Path | None,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
) -> None:
"""Monitor a process until completion."""
try:
# Stream output from the process if stdout is available
if process.stdout is not None:
while True:
line = await process.stdout.readline()
if not line and process.returncode is not None:
break
if line:
on_log_message_cb(line.decode().strip())
else:
# If stdout is not available, just wait for the process to complete
on_log_message_cb("Process stdout not available, waiting for completion")
await process.wait()
# Wait for process to complete if not already done
if process.returncode is None:
await process.wait()
# Check if process failed
if process.returncode != 0:
on_log_message_cb(f"Training failed with return code {process.returncode}")
job.status = JobStatus.failed
else:
on_status_change_cb(JobStatus.completed)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(f"Error monitoring process: {str(e)}")
job.status = JobStatus.failed
logger.exception(f"Error monitoring process for job {job.id}")
finally:
# Clean up temporary files
if script_path and script_path.exists():
script_path.unlink()
# Remove from active jobs
if job.id in self._active_jobs:
del self._active_jobs[job.id]
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
"distributed": DistributedJobScheduler,
}
@@ -230,11 +429,18 @@ class Scheduler:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
- def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
# Set world size if provided
if "world_size" in run_params:
job.world_size = run_params["world_size"]
# Store all run parameters in the job's run_args dictionary
job.run_args = run_params
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(

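To tie the pieces together, a hypothetical call into the updated scheduler as the HuggingFace provider uses it: the backend name, the new run_params argument, and handler=None for the distributed path come from this commit, while the job-type string and all values are placeholders.

# Hypothetical scheduling of a distributed fine-tune; world_size > 1 makes the
# distributed backend shell out to torchrun, so handler is passed as None.
import json

from llama_stack.providers.utils.scheduler import Scheduler

scheduler = Scheduler(backend="distributed")
run_params = {
    "world_size": 2,                  # triggers the torchrun path
    "job_uuid": "job-1234",
    "model": "my-base-model",
    "training_config": "{}",          # TrainingConfig.model_dump_json()
    "provider_config": "{}",          # HuggingFacePostTrainingConfig.model_dump_json()
    "data": json.dumps([{"chat_completion_input": "...", "expected_answer": "..."}]),
}
job_id = scheduler.schedule("supervised-fine-tune", "job-1234", None, run_params)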
View file

@@ -100,6 +100,7 @@ providers:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
recipe: single
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search

View file

@@ -98,6 +98,7 @@ providers:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
recipe: single
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search