mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-29 03:14:19 +00:00)
feat: DistributedJobScheduler
Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which kicks off a `finetune_handler.py` script using torchrun. The handler processes the training args via argparse and calls the right recipe, as `post_training.py` used to do. Torchrun takes care of environment variables such as world_size, local_rank, etc.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent 6494658a10
commit ce48d47543
5 changed files with 534 additions and 143 deletions
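The handler script referenced in the commit message is not part of the hunks shown below, so the following is only a rough, hypothetical sketch of the pattern it describes: an entry point launched by torchrun that parses the flags the scheduler forwards and reads the environment variables torchrun exports. The flag names and the recipe hand-off here are illustrative assumptions, not the actual `finetune_handler.py`.

```python
# Hypothetical sketch of a torchrun-launched handler entry point; the flag
# names and the recipe hand-off are illustrative, not the real finetune_handler.py.
import argparse
import os


def main() -> None:
    parser = argparse.ArgumentParser(description="distributed finetuning entry point")
    parser.add_argument("--job_uuid", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    # Booleans arrive as bare flags, matching how the scheduler builds the command line.
    parser.add_argument("--use_lora", action="store_true")
    args = parser.parse_args()

    # torchrun exports these for every worker process it spawns.
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    print(f"worker {rank}/{world_size} (local_rank={local_rank}) starting job {args.job_uuid}")
    # ...instantiate and run the appropriate training recipe here...


if __name__ == "__main__":
    main()
```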
@@ -7,10 +7,12 @@
 import abc
 import asyncio
 import functools
+import multiprocessing
 import threading
 from collections.abc import Callable, Coroutine, Iterable
 from datetime import datetime, timezone
 from enum import Enum
+from pathlib import Path
 from typing import Any, TypeAlias

 from pydantic import BaseModel
@@ -54,7 +56,7 @@ _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


 class Job:
-    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
+    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
         super().__init__()
         self.id = job_id
         self._type = job_type
@@ -62,9 +64,38 @@ class Job:
         self._artifacts: list[JobArtifact] = []
         self._logs: list[LogMessage] = []
         self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
+        self._child_processes: list[multiprocessing.Process] = []
+        self._world_size: int = 1  # Number of processes for distributed training
+        self.run_args: dict[str, Any] = {}  # Dictionary to store run arguments

     @property
-    def handler(self) -> JobHandler:
+    def world_size(self) -> int:
+        return self._world_size
+
+    @world_size.setter
+    def world_size(self, size: int) -> None:
+        self._world_size = size
+
+    def add_child_process(self, process: multiprocessing.Process) -> None:
+        self._child_processes.append(process)
+
+    def cancel(self) -> None:
+        """Cancel the job and all its child processes."""
+        for process in self._child_processes:
+            if process.is_alive():
+                process.terminate()
+                process.join(timeout=5)
+        self.status = JobStatus.failed
+
+    def cleanup(self) -> None:
+        """Clean up any remaining child processes."""
+        for process in self._child_processes:
+            if process.is_alive():
+                process.terminate()
+                process.join(timeout=5)
+
+    @property
+    def handler(self) -> JobHandler | None:
         return self._handler

     @property
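The new child-process bookkeeping on `Job` boils down to a terminate-then-join pattern. As a minimal, standalone sketch of what `cancel()` and `cleanup()` do for each process registered via `add_child_process()` (the worker function below is a made-up stand-in for a training process):

```python
# Minimal sketch of the terminate-then-join pattern Job.cancel()/Job.cleanup()
# apply to registered child processes; the worker is a made-up stand-in.
import multiprocessing
import time


def _worker() -> None:
    time.sleep(60)  # stand-in for a long-running training process


if __name__ == "__main__":
    proc = multiprocessing.Process(target=_worker)
    proc.start()

    # What Job.cancel() does for each process added via add_child_process():
    if proc.is_alive():
        proc.terminate()      # ask the child to exit
        proc.join(timeout=5)  # wait up to 5 seconds for it to do so
    print(f"child exited with code {proc.exitcode}")
```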
@@ -111,10 +142,6 @@ class Job:
     def append_log(self, message: LogMessage) -> None:
         self._logs.append(message)

-    # TODO: implement
-    def cancel(self) -> None:
-        raise NotImplementedError
-

 class _SchedulerBackend(abc.ABC):
     @abc.abstractmethod
@@ -148,8 +175,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
     def __init__(self, timeout: int = 5):
         self._timeout = timeout
         self._loop = asyncio.new_event_loop()
-        # There may be performance implications of using threads due to Python
-        # GIL; may need to measure if it's a real problem though
         self._thread = threading.Thread(target=self._run_loop, daemon=True)
         self._thread.start()

@@ -158,7 +183,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
         self._loop.run_forever()

         # When stopping the loop, give tasks a chance to finish
-        # TODO: should we explicitly inform jobs of pending stoppage?
         for task in asyncio.all_tasks(self._loop):
             self._loop.run_until_complete(task)
         self._loop.close()
@@ -167,7 +191,6 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
         self._loop.call_soon_threadsafe(self._loop.stop)
         self._thread.join()

-    # TODO: decouple scheduling and running the job
     def schedule(
         self,
         job: Job,
@@ -179,6 +202,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
             try:
                 job.status = JobStatus.running
                 await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
+                job.status = JobStatus.completed
             except Exception as e:
                 on_log_message_cb(str(e))
                 job.status = JobStatus.failed
@@ -196,8 +220,183 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
         pass


+class DistributedJobScheduler(_SchedulerBackend):
+    """A scheduler backend that supports distributed training jobs.
+
+    This scheduler uses torchrun to handle distributed training process spawning and coordination.
+    torchrun automatically handles:
+    - Process spawning
+    - Environment variable setup
+    - Process group initialization
+    - Error handling and process cleanup
+    """
+
+    def __init__(self, timeout: int = 5):
+        self._timeout = timeout
+        self._loop = asyncio.new_event_loop()
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+        self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
+
+    def _run_loop(self) -> None:
+        asyncio.set_event_loop(self._loop)
+        self._loop.run_forever()
+
+        # When stopping the loop, give tasks a chance to finish
+        for task in asyncio.all_tasks(self._loop):
+            self._loop.run_until_complete(task)
+        self._loop.close()
+
+    async def shutdown(self) -> None:
+        # Clean up any remaining processes
+        for process in self._active_jobs.values():
+            if process.returncode is None:  # Process is still running
+                process.terminate()
+                try:
+                    await asyncio.wait_for(process.wait(), timeout=5)
+                except asyncio.TimeoutError:
+                    process.kill()
+                    await process.wait()
+
+        self._loop.call_soon_threadsafe(self._loop.stop)
+        self._thread.join()
+
+    def schedule(
+        self,
+        job: Job,
+        on_log_message_cb: Callable[[str], None],
+        on_status_change_cb: Callable[[JobStatus], None],
+        on_artifact_collected_cb: Callable[[JobArtifact], None],
+    ) -> None:
+        async def do():
+            try:
+                job.status = JobStatus.running
+
+                # If this is a distributed training job, use torchrun
+                if job.world_size > 1:
+                    # Find the path to finetune_handler.py
+                    from llama_stack.providers.inline.post_training.huggingface import finetune_handler
+
+                    handler_path = Path(finetune_handler.__file__)
+
+                    # Prepare arguments for the handler script
+                    args = [
+                        "torchrun",
+                        f"--nproc_per_node={job.world_size}",
+                        "--master_addr=localhost",
+                        "--master_port=29500",
+                        str(handler_path),
+                    ]
+
+                    # Add arguments from the job.run_args dictionary as proper command-line flags
+                    for arg_name, arg_value in job.run_args.items():
+                        # Skip world_size as we've already handled it
+                        if arg_name == "world_size":
+                            continue
+
+                        if arg_value is not None:
+                            # Handle boolean flags
+                            if isinstance(arg_value, bool):
+                                if arg_value:
+                                    args.append(f"--{arg_name}")
+                            else:
+                                # For non-boolean values, we add the argument as a separate flag and value
+                                args.append(f"--{arg_name}")
+                                args.append(str(arg_value))
+
+                    # Launch torchrun using asyncio
+                    on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
+                    on_log_message_cb(f"Command: {' '.join(args)}")
+
+                    # Make sure we capture stdout and stderr
+                    process = await asyncio.create_subprocess_exec(
+                        *args,
+                        stdout=asyncio.subprocess.PIPE,
+                        stderr=asyncio.subprocess.STDOUT,
+                    )
+
+                    # Store process for this job
+                    self._active_jobs[job.id] = process
+
+                    # Start monitoring in a separate task so we don't block
+                    asyncio.create_task(
+                        self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
+                    )
+                else:
+                    # For single-device training, call the handler directly if provided
+                    if job.handler:
+                        await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
+                        job.status = JobStatus.completed
+                    else:
+                        on_log_message_cb("No handler function provided for single-device training")
+                        job.status = JobStatus.failed
+            except Exception as e:
+                on_log_message_cb(str(e))
+                job.status = JobStatus.failed
+                logger.exception(f"Job {job.id} failed.")
+
+        asyncio.run_coroutine_threadsafe(do(), self._loop)
+
+    async def _monitor_process(
+        self,
+        job: Job,
+        process: asyncio.subprocess.Process,
+        script_path: Path | None,
+        on_log_message_cb: Callable[[str], None],
+        on_status_change_cb: Callable[[JobStatus], None],
+    ) -> None:
+        """Monitor a process until completion."""
+        try:
+            # Stream output from the process if stdout is available
+            if process.stdout is not None:
+                while True:
+                    line = await process.stdout.readline()
+                    if not line and process.returncode is not None:
+                        break
+                    if line:
+                        on_log_message_cb(line.decode().strip())
+            else:
+                # If stdout is not available, just wait for the process to complete
+                on_log_message_cb("Process stdout not available, waiting for completion")
+                await process.wait()
+
+            # Wait for process to complete if not already done
+            if process.returncode is None:
+                await process.wait()
+
+            # Check if process failed
+            if process.returncode != 0:
+                on_log_message_cb(f"Training failed with return code {process.returncode}")
+                job.status = JobStatus.failed
+            else:
+                on_status_change_cb(JobStatus.completed)
+                job.status = JobStatus.completed
+        except Exception as e:
+            on_log_message_cb(f"Error monitoring process: {str(e)}")
+            job.status = JobStatus.failed
+            logger.exception(f"Error monitoring process for job {job.id}")
+        finally:
+            # Clean up temporary files
+            if script_path and script_path.exists():
+                script_path.unlink()
+
+            # Remove from active jobs
+            if job.id in self._active_jobs:
+                del self._active_jobs[job.id]
+
+    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
+        pass
+
+    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
+        pass
+
+    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
+        pass
+
+
 _BACKENDS = {
     "naive": _NaiveSchedulerBackend,
+    "distributed": DistributedJobScheduler,
 }


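To make the argv construction in `DistributedJobScheduler.schedule()` concrete, here is a standalone sketch applying the same flag rules to a made-up `run_args` dict: true booleans become bare flags, other values become `--name value` pairs, `None` values are dropped, and `world_size` is skipped because it is already expressed as `--nproc_per_node`. All parameter values below are illustrative, not taken from the diff.

```python
# Standalone illustration of the command line DistributedJobScheduler builds;
# every run_args value below is a made-up example.
from typing import Any

world_size = 2
run_args: dict[str, Any] = {
    "job_uuid": "job-123",
    "model": "my-base-model",
    "use_lora": True,           # True booleans become bare flags
    "checkpoint_dir": None,     # None values are dropped
    "world_size": world_size,   # already expressed via --nproc_per_node, so skipped
}

args = [
    "torchrun",
    f"--nproc_per_node={world_size}",
    "--master_addr=localhost",
    "--master_port=29500",
    "finetune_handler.py",
]
for name, value in run_args.items():
    if name == "world_size" or value is None:
        continue
    if isinstance(value, bool):
        if value:
            args.append(f"--{name}")
    else:
        args.extend([f"--{name}", str(value)])

print(" ".join(args))
# -> torchrun --nproc_per_node=2 --master_addr=localhost --master_port=29500
#    finetune_handler.py --job_uuid job-123 --model my-base-model --use_lora
```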
@@ -230,11 +429,18 @@ class Scheduler:
         job.register_artifact(artifact)
         self._backend.on_artifact_collected_cb(job, artifact)

-    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
+    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
         job = Job(type_, job_id, handler)
         if job.id in self._jobs:
             raise ValueError(f"Job {job.id} already exists")

+        # Set world size if provided
+        if "world_size" in run_params:
+            job.world_size = run_params["world_size"]
+
+        # Store all run parameters in the job's run_args dictionary
+        job.run_args = run_params
+
         self._jobs[job.id] = job
         job.status = JobStatus.scheduled
         self._backend.schedule(
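Finally, a hedged sketch of how a post-training provider might call the updated `Scheduler.schedule()`. The `"distributed"` backend name, the `handler=None` possibility, and the `world_size`/`run_params` handling come from this diff; the import path, the `backend=` constructor argument, the `JobType` member name, and the training flags are assumptions made for illustration only.

```python
# Sketch of scheduling a multi-GPU job through the updated Scheduler API.
# Assumptions: the module path, the backend= constructor argument, and the
# JobType member name; the job id and training flags are made-up examples.
from llama_stack.providers.utils.scheduler import JobType, Scheduler

scheduler = Scheduler(backend="distributed")

job_id = scheduler.schedule(
    type_=JobType.finetuning,  # assumed member name for the post-training job type
    job_id="job-123",
    handler=None,              # torchrun path: no in-process handler is needed
    run_params={
        "world_size": 2,       # > 1 routes the job through torchrun
        "job_uuid": "job-123",
        "model": "my-base-model",
        "use_lora": True,
    },
)
```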