llama-stack-mirror/llama_stack/providers/utils/scheduler.py
Charlie Doern ce48d47543 feat: DistributedJobScheduler
Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which kicks off a `finetune_handler.py` script using torchrun. The handler processes the training args via argparse
and calls the right recipe, as `post_training.py` used to do. torchrun takes care of env variables like world_size, local_rank, etc. (a usage sketch follows below).

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-06-12 16:05:45 -04:00
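
For illustration, a hedged sketch of how a distributed job might be submitted through the `Scheduler` defined in this file; the job type, job ID, and the `model` key in `run_params` are assumptions, and only `world_size` is interpreted by the scheduler itself (everything else is forwarded to `finetune_handler.py` as command-line flags):

    from llama_stack.providers.utils.scheduler import Scheduler

    scheduler = Scheduler(backend="distributed")
    job_id = scheduler.schedule(
        type_="finetune",
        job_id="job-123",
        handler=None,  # unused when world_size > 1; torchrun launches finetune_handler.py instead
        run_params={"world_size": 2, "model": "my-base-model"},
    )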


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import multiprocessing
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, TypeAlias
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
JobID: TypeAlias = str
JobType: TypeAlias = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
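# A minimal sketch of a coroutine matching the JobHandler signature above; the function name
# and its body are illustrative and not part of this module:
#
#     async def example_handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
#         on_log_message_cb("starting work")
#         on_artifact_collected_cb(JobArtifact(type="checkpoint", name="final", metadata={}))
#         on_status_change_cb(JobStatus.completed)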
LogMessage: TypeAlias = tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
self._child_processes: list[multiprocessing.Process] = []
self._world_size: int = 1 # Number of processes for distributed training
self.run_args: dict[str, Any] = {} # Dictionary to store run arguments
@property
def world_size(self) -> int:
return self._world_size
@world_size.setter
def world_size(self, size: int) -> None:
self._world_size = size
def add_child_process(self, process: multiprocessing.Process) -> None:
self._child_processes.append(process)
def cancel(self) -> None:
"""Cancel the job and all its child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
self.status = JobStatus.failed
def cleanup(self) -> None:
"""Clean up any remaining child processes."""
for process in self._child_processes:
if process.is_alive():
process.terminate()
process.join(timeout=5)
@property
def handler(self) -> JobHandler | None:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(timezone.utc), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
class DistributedJobScheduler(_SchedulerBackend):
"""A scheduler backend that supports distributed training jobs.
This scheduler uses torchrun to handle distributed training process spawning and coordination.
torchrun automatically handles:
- Process spawning
- Environment variable setup
- Process group initialization
- Error handling and process cleanup
"""
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
# Clean up any remaining processes
for process in self._active_jobs.values():
if process.returncode is None: # Process is still running
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
process.kill()
await process.wait()
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
# If this is a distributed training job, use torchrun
if job.world_size > 1:
# Find the path to finetune_handler.py
from llama_stack.providers.inline.post_training.huggingface import finetune_handler
handler_path = Path(finetune_handler.__file__)
# Prepare arguments for the handler script
args = [
"torchrun",
f"--nproc_per_node={job.world_size}",
"--master_addr=localhost",
"--master_port=29500",
str(handler_path),
]
# Add arguments from the job.run_args dictionary as proper command-line flags
for arg_name, arg_value in job.run_args.items():
# Skip world_size as we've already handled it
if arg_name == "world_size":
continue
if arg_value is not None:
# Handle boolean flags
if isinstance(arg_value, bool):
if arg_value:
args.append(f"--{arg_name}")
else:
# For non-boolean values, we add the argument as a separate flag and value
args.append(f"--{arg_name}")
args.append(str(arg_value))
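                    # Illustrative mapping (assumed values): run_args={"model": "m", "use_lora": True}
                    # yields ["--model", "m", "--use_lora"]; False and None values add no flags.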
# Launch torchrun using asyncio
on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
on_log_message_cb(f"Command: {' '.join(args)}")
# Make sure we capture stdout and stderr
process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
# Store process for this job
self._active_jobs[job.id] = process
# Start monitoring in a separate task so we don't block
asyncio.create_task(
self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
)
else:
# For single-device training, call the handler directly if provided
if job.handler:
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
job.status = JobStatus.completed
else:
on_log_message_cb("No handler function provided for single-device training")
job.status = JobStatus.failed
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
async def _monitor_process(
self,
job: Job,
process: asyncio.subprocess.Process,
script_path: Path | None,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
) -> None:
"""Monitor a process until completion."""
try:
# Stream output from the process if stdout is available
if process.stdout is not None:
                while True:
                    line = await process.stdout.readline()
                    if not line:
                        # EOF: the child closed stdout, typically because it exited
                        break
                    on_log_message_cb(line.decode().strip())
else:
# If stdout is not available, just wait for the process to complete
on_log_message_cb("Process stdout not available, waiting for completion")
await process.wait()
# Wait for process to complete if not already done
if process.returncode is None:
await process.wait()
            # Report the final status; a non-zero return code means the training run failed
            if process.returncode != 0:
                on_log_message_cb(f"Training failed with return code {process.returncode}")
                on_status_change_cb(JobStatus.failed)
            else:
                on_status_change_cb(JobStatus.completed)
except Exception as e:
on_log_message_cb(f"Error monitoring process: {str(e)}")
job.status = JobStatus.failed
logger.exception(f"Error monitoring process for job {job.id}")
finally:
# Clean up temporary files
if script_path and script_path.exists():
script_path.unlink()
# Remove from active jobs
if job.id in self._active_jobs:
del self._active_jobs[job.id]
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
"distributed": DistributedJobScheduler,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
        # TODO: if the server crashes, job states are lost; we need to persist jobs on disk
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(timezone.utc), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, also log messages to the console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
# Set world size if provided
if "world_size" in run_params:
job.world_size = run_params["world_size"]
# Store all run parameters in the job's run_args dictionary
job.run_args = run_params
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()