feat: DistributedJobScheduler

Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which launches a `finetune_handler.py` script using torchrun. The handler parses the training args via argparse
and calls the appropriate recipe, as `post_training.py` used to do. torchrun takes care of environment variables such as world_size and local_rank.
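
The handler script itself is not part of the hunks shown below, so the following is only a rough sketch of the kind of argparse entrypoint this design expects; the flag names (`--model`, `--job_uuid`) are illustrative, not the handler's actual interface:

```python
# Hypothetical finetune_handler.py skeleton -- illustrative only, not the file in this commit.
import argparse
import os


def main() -> None:
    parser = argparse.ArgumentParser(description="Fine-tuning entrypoint launched by torchrun")
    # The scheduler forwards each run_args entry as "--<name> <value>" (booleans as bare flags),
    # so the handler only needs matching argparse flags. These two are assumed examples.
    parser.add_argument("--model", required=True)
    parser.add_argument("--job_uuid", required=True)
    args = parser.parse_args()

    # torchrun sets these before the script starts; no manual process management needed.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    # ...dispatch to the appropriate training recipe here, as post_training.py used to...
    print(f"[rank {local_rank}/{world_size}] fine-tuning {args.model} for job {args.job_uuid}")


if __name__ == "__main__":
    main()
```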

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern 2025-06-12 13:59:06 -04:00
parent 6494658a10
commit ce48d47543
5 changed files with 534 additions and 143 deletions
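
For context before the diff, here is a rough sketch of how a caller might schedule a distributed job against the new backend. Only the updated `schedule()` signature and the `"distributed"` key in `_BACKENDS` come from the diff below; the import path, the `Scheduler` constructor argument, the `JobType` member, and the `run_params` keys are assumptions for illustration:

```python
# Usage sketch only: backend selection, JobType member, and run_params keys are assumed.
from llama_stack.providers.utils.scheduler import JobType, Scheduler  # import path assumed

scheduler = Scheduler(backend="distributed")  # assumed: "distributed" maps to DistributedJobScheduler

job_id = scheduler.schedule(
    type_=JobType.supervised_fine_tune,  # illustrative member name
    job_id="sft-job-1234",
    handler=None,  # no in-process handler; torchrun launches finetune_handler.py instead
    run_params={
        "world_size": 2,                     # > 1 routes the job through torchrun
        "model": "meta-llama/Llama-3.2-1B",  # forwarded as "--model <value>"
        "job_uuid": "sft-job-1234",          # forwarded as "--job_uuid <value>"
    },
)
```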

@@ -7,10 +7,12 @@
import abc
import asyncio
import functools
import multiprocessing
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, TypeAlias
from pydantic import BaseModel
@@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
super().__init__()
self.id = job_id
self._type = job_type
@@ -62,9 +64,38 @@ class Job:
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
self._child_processes: list[multiprocessing.Process] = []
self._world_size: int = 1 # Number of processes for distributed training
self.run_args: dict[str, Any] = {} # Dictionary to store run arguments
@property
def handler(self) -> JobHandler:
def world_size(self) -> int:
return self._world_size
@world_size.setter
def world_size(self, size: int) -> None:
self._world_size = size
def add_child_process(self, process: multiprocessing.Process) -> None:
self._child_processes.append(process)
def cancel(self) -> None:
"""Cancel the job and all its child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
self.status = JobStatus.failed
def cleanup(self) -> None:
"""Clean up any remaining child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
@property
def handler(self) -> JobHandler | None:
return self._handler
@property
@@ -111,10 +142,6 @@ class Job:
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
@@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
@@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
@@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
@@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
@@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
pass
class DistributedJobScheduler(_SchedulerBackend):
"""A scheduler backend that supports distributed training jobs.
This scheduler uses torchrun to handle distributed training process spawning and coordination.
torchrun automatically handles:
- Process spawning
- Environment variable setup
- Process group initialization
- Error handling and process cleanup
"""
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
# Clean up any remaining processes
for process in self._active_jobs.values():
if process.returncode is None: # Process is still running
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
process.kill()
await process.wait()
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
# If this is a distributed training job, use torchrun
if job.world_size > 1:
# Find the path to finetune_handler.py
from llama_stack.providers.inline.post_training.huggingface import finetune_handler
handler_path = Path(finetune_handler.__file__)
# Prepare arguments for the handler script
args = [
"torchrun",
f"--nproc_per_node={job.world_size}",
"--master_addr=localhost",
"--master_port=29500",
str(handler_path),
]
# Add arguments from the job.run_args dictionary as proper command-line flags
for arg_name, arg_value in job.run_args.items():
# Skip world_size as we've already handled it
if arg_name == "world_size":
continue
if arg_value is not None:
# Handle boolean flags
if isinstance(arg_value, bool):
if arg_value:
args.append(f"--{arg_name}")
else:
# For non-boolean values, we add the argument as a separate flag and value
args.append(f"--{arg_name}")
args.append(str(arg_value))
# Launch torchrun using asyncio
on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
on_log_message_cb(f"Command: {' '.join(args)}")
# Make sure we capture stdout and stderr
process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
# Store process for this job
self._active_jobs[job.id] = process
# Start monitoring in a separate task so we don't block
asyncio.create_task(
self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
)
else:
# For single-device training, call the handler directly if provided
if job.handler:
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
else:
on_log_message_cb("No handler function provided for single-device training")
job.status = JobStatus.failed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
async def _monitor_process(
self,
job: Job,
process: asyncio.subprocess.Process,
script_path: Path | None,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
) -> None:
"""Monitor a process until completion."""
try:
# Stream output from the process if stdout is available
if process.stdout is not None:
while True:
line = await process.stdout.readline()
if not line and process.returncode is not None:
break
if line:
on_log_message_cb(line.decode().strip())
else:
# If stdout is not available, just wait for the process to complete
on_log_message_cb("Process stdout not available, waiting for completion")
await process.wait()
# Wait for process to complete if not already done
if process.returncode is None:
await process.wait()
# Check if process failed
if process.returncode != 0:
on_log_message_cb(f"Training failed with return code {process.returncode}")
job.status = JobStatus.failed
else:
on_status_change_cb(JobStatus.completed)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(f"Error monitoring process: {str(e)}")
job.status = JobStatus.failed
logger.exception(f"Error monitoring process for job {job.id}")
finally:
# Clean up temporary files
if script_path and script_path.exists():
script_path.unlink()
# Remove from active jobs
if job.id in self._active_jobs:
del self._active_jobs[job.id]
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
"distributed": DistributedJobScheduler,
}
@@ -230,11 +429,18 @@ class Scheduler:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
# Set world size if provided
if "world_size" in run_params:
job.world_size = run_params["world_size"]
# Store all run parameters in the job's run_args dictionary
job.run_args = run_params
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
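
Beyond the hunks above, it may help to see what the torchrun contract buys the handler: each spawned worker can derive its distributed context purely from the environment. This is a minimal sketch, not code from this commit; the `"gloo"` backend is a placeholder (a GPU recipe would typically use `"nccl"`):

```python
# Minimal sketch of per-worker setup under torchrun (illustrative, not part of this commit).
import os

import torch.distributed as dist


def init_distributed() -> tuple[int, int, int]:
    # torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
    # for every process it spawns, matching the commit message above.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # "env://" tells torch.distributed to read the rendezvous info from those variables.
    dist.init_process_group(backend="gloo", init_method="env://")
    return rank, local_rank, world_size
```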