Rather than handling multi-GPU training within a recipe, distributed training should be one of our scheduler offerings. Introduce the DistributedJobScheduler, which kicks off a `finetune_handler.py` script using torchrun. The handler parses the training args via argparse and calls the right recipe, as `post_training.py` used to do. torchrun takes care of environment variables like world_size, local_rank, etc.

Signed-off-by: Charlie Doern <cdoern@redhat.com>

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import abc
import asyncio
import functools
import multiprocessing
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, TypeAlias

from pydantic import BaseModel

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="scheduler")


# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
    new = "new"
    scheduled = "scheduled"
    running = "running"
    failed = "failed"
    completed = "completed"


JobID: TypeAlias = str
JobType: TypeAlias = str


class JobArtifact(BaseModel):
    type: JobType
    name: str
    # TODO: uri should be a reference to /files API; revisit when /files is implemented
    uri: str | None = None
    metadata: dict[str, Any]


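# For example (illustrative values; the metadata keys are hypothetical), a
# training run might register a checkpoint artifact as:
#
#     JobArtifact(
#         type="checkpoint",
#         name="epoch-1",
#         uri="/tmp/checkpoints/epoch-1",
#         metadata={"epoch": 1, "loss": 0.42},
#     )

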
JobHandler = Callable[
    [Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]


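# Illustrative sketch (not part of the original module): a coroutine matching
# the JobHandler signature above. The body is made up; a real handler runs a
# training recipe and reports progress through the three callbacks.
async def _example_handler(
    on_log_message_cb: Callable[[str], None],
    on_status_change_cb: Callable[[JobStatus], None],
    on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
    on_log_message_cb("starting work")
    on_artifact_collected_cb(JobArtifact(type="checkpoint", name="epoch-0", metadata={}))
    on_status_change_cb(JobStatus.completed)

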
LogMessage: TypeAlias = tuple[datetime, str]


_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


class Job:
    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler | None):
        super().__init__()
        self.id = job_id
        self._type = job_type
        self._handler = handler
        self._artifacts: list[JobArtifact] = []
        self._logs: list[LogMessage] = []
        self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
        self._child_processes: list[multiprocessing.Process] = []
        self._world_size: int = 1  # Number of processes for distributed training
        self.run_args: dict[str, Any] = {}  # Dictionary to store run arguments

    @property
    def world_size(self) -> int:
        return self._world_size

    @world_size.setter
    def world_size(self, size: int) -> None:
        self._world_size = size

    def add_child_process(self, process: multiprocessing.Process) -> None:
        self._child_processes.append(process)

    def cancel(self) -> None:
        """Cancel the job and all its child processes.

        Note: only processes registered via add_child_process are terminated
        here; the distributed backend tracks its torchrun subprocess itself
        and cleans it up on shutdown.
        """
        for process in self._child_processes:
            if process.is_alive():
                process.terminate()
                process.join(timeout=5)
        self.status = JobStatus.failed

    def cleanup(self) -> None:
        """Clean up any remaining child processes."""
        for process in self._child_processes:
            if process.is_alive():
                process.terminate()
                process.join(timeout=5)

    @property
    def handler(self) -> JobHandler | None:
        return self._handler

    @property
    def status(self) -> JobStatus:
        return self._state_transitions[-1][1]

    @status.setter
    def status(self, status: JobStatus):
        if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
            raise ValueError(f"Job is already in a completed state ({self.status})")
        if self.status == status:
            return
        self._state_transitions.append((datetime.now(timezone.utc), status))

    @property
    def artifacts(self) -> list[JobArtifact]:
        return self._artifacts

    def register_artifact(self, artifact: JobArtifact) -> None:
        self._artifacts.append(artifact)

    def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
        for date, s in reversed(self._state_transitions):
            if s in status:
                return date
        return None

    @property
    def scheduled_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.scheduled])

    @property
    def started_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.running])

    @property
    def completed_at(self) -> datetime | None:
        return self._find_state_transition_date(_COMPLETED_STATUSES)

    @property
    def logs(self) -> list[LogMessage]:
        return self._logs[:]

    def append_log(self, message: LogMessage) -> None:
        self._logs.append(message)


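# Usage sketch (illustrative): transitions are recorded with UTC timestamps,
# and leaving a completed state raises ValueError.
#
#     job = Job("finetune", "job-1", handler=None)
#     job.status = JobStatus.scheduled
#     job.status = JobStatus.running
#     job.status = JobStatus.completed
#     job.completed_at               # timestamp of the completed transition
#     job.status = JobStatus.failed  # ValueError: already in a completed state

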
class _SchedulerBackend(abc.ABC):
    @abc.abstractmethod
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def shutdown(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        raise NotImplementedError


class _NaiveSchedulerBackend(_SchedulerBackend):
    def __init__(self, timeout: int = 5):
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # When stopping the loop, give tasks a chance to finish
        for task in asyncio.all_tasks(self._loop):
            self._loop.run_until_complete(task)
        self._loop.close()

    async def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do():
            try:
                job.status = JobStatus.running
                # The naive backend always runs the handler in-process, so it
                # must be provided; fail with a clear message otherwise.
                if job.handler is None:
                    raise RuntimeError(f"No handler provided for job {job.id}")
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
                job.status = JobStatus.completed
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed
                logger.exception(f"Job {job.id} failed.")

        asyncio.run_coroutine_threadsafe(do(), self._loop)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


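# Usage sketch (illustrative): driving the naive backend directly with the
# _example_handler defined above; in practice the Scheduler wires this up.
#
#     backend = _NaiveSchedulerBackend()
#     job = Job("example", "job-0", _example_handler)
#     backend.schedule(job, print, lambda status: None, lambda artifact: None)

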
class DistributedJobScheduler(_SchedulerBackend):
    """A scheduler backend that supports distributed training jobs.

    This scheduler uses torchrun to handle distributed training process spawning and coordination.
    torchrun automatically handles:
    - Process spawning
    - Environment variable setup
    - Process group initialization
    - Error handling and process cleanup
    """

    def __init__(self, timeout: int = 5):
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._active_jobs: dict[JobID, asyncio.subprocess.Process] = {}
        # Hold references to monitoring tasks so they are not garbage-collected
        # before they complete.
        self._monitor_tasks: dict[JobID, asyncio.Task] = {}

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # When stopping the loop, give tasks a chance to finish
        for task in asyncio.all_tasks(self._loop):
            self._loop.run_until_complete(task)
        self._loop.close()

    async def shutdown(self) -> None:
        # Clean up any remaining processes
        for process in self._active_jobs.values():
            if process.returncode is None:  # Process is still running
                process.terminate()
                try:
                    await asyncio.wait_for(process.wait(), timeout=5)
                except asyncio.TimeoutError:
                    process.kill()
                    await process.wait()

        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do():
            try:
                job.status = JobStatus.running

                # If this is a distributed training job, use torchrun
                if job.world_size > 1:
                    # Find the path to finetune_handler.py
                    from llama_stack.providers.inline.post_training.huggingface import finetune_handler

                    handler_path = Path(finetune_handler.__file__)

                    # Prepare arguments for the handler script
                    args = [
                        "torchrun",
                        f"--nproc_per_node={job.world_size}",
                        "--master_addr=localhost",
                        "--master_port=29500",
                        str(handler_path),
                    ]

                    # Add arguments from the job.run_args dictionary as proper command-line flags
                    for arg_name, arg_value in job.run_args.items():
                        # Skip world_size as we've already handled it
                        if arg_name == "world_size":
                            continue

                        if arg_value is not None:
                            # Boolean flags are passed without a value
                            if isinstance(arg_value, bool):
                                if arg_value:
                                    args.append(f"--{arg_name}")
                            else:
                                # For non-boolean values, add the flag and its value separately
                                args.append(f"--{arg_name}")
                                args.append(str(arg_value))

                    # Launch torchrun using asyncio
                    on_log_message_cb(f"Launching distributed training with {job.world_size} processes")
                    on_log_message_cb(f"Command: {' '.join(args)}")

                    # Make sure we capture stdout and stderr
                    process = await asyncio.create_subprocess_exec(
                        *args,
                        stdout=asyncio.subprocess.PIPE,
                        stderr=asyncio.subprocess.STDOUT,
                    )

                    # Store the process for this job
                    self._active_jobs[job.id] = process

                    # Start monitoring in a separate task so we don't block;
                    # keep a reference so the task isn't garbage-collected.
                    self._monitor_tasks[job.id] = asyncio.create_task(
                        self._monitor_process(job, process, None, on_log_message_cb, on_status_change_cb)
                    )
                else:
                    # For single-device training, call the handler directly if provided
                    if job.handler:
                        await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
                        job.status = JobStatus.completed
                    else:
                        on_log_message_cb("No handler function provided for single-device training")
                        job.status = JobStatus.failed
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed
                logger.exception(f"Job {job.id} failed.")

        asyncio.run_coroutine_threadsafe(do(), self._loop)

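    # For illustration, a job with world_size=2 and hypothetical
    # run_args={"world_size": 2, "model": "llama-3", "use_lora": True}
    # produces a command along the lines of:
    #
    #     torchrun --nproc_per_node=2 --master_addr=localhost --master_port=29500 \
    #         .../finetune_handler.py --model llama-3 --use_lora
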
    async def _monitor_process(
        self,
        job: Job,
        process: asyncio.subprocess.Process,
        script_path: Path | None,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
    ) -> None:
        """Monitor a process until completion."""
        try:
            # Stream output from the process if stdout is available
            if process.stdout is not None:
                while True:
                    line = await process.stdout.readline()
                    if not line:
                        # EOF: the process has closed its output stream.
                        # (Breaking only when returncode is already set would
                        # busy-loop on empty reads until the exit status lands.)
                        break
                    on_log_message_cb(line.decode().strip())
            else:
                # If stdout is not available, just wait for the process to complete
                on_log_message_cb("Process stdout not available, waiting for completion")
                await process.wait()

            # Wait for the process to complete if not already done
            if process.returncode is None:
                await process.wait()

            # Check whether the process failed
            if process.returncode != 0:
                on_log_message_cb(f"Training failed with return code {process.returncode}")
                job.status = JobStatus.failed
            else:
                on_status_change_cb(JobStatus.completed)
                job.status = JobStatus.completed
        except Exception as e:
            on_log_message_cb(f"Error monitoring process: {str(e)}")
            job.status = JobStatus.failed
            logger.exception(f"Error monitoring process for job {job.id}")
        finally:
            # Clean up temporary files
            if script_path and script_path.exists():
                script_path.unlink()

            # Remove bookkeeping for this job
            self._active_jobs.pop(job.id, None)
            self._monitor_tasks.pop(job.id, None)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


_BACKENDS = {
    "naive": _NaiveSchedulerBackend,
    "distributed": DistributedJobScheduler,
}


def _get_backend_impl(backend: str) -> _SchedulerBackend:
    try:
        return _BACKENDS[backend]()
    except KeyError as e:
        raise ValueError(f"Unknown backend {backend}") from e


class Scheduler:
    def __init__(self, backend: str = "naive"):
        # TODO: if server crashes, job states are lost; we need to persist jobs on disk
        self._jobs: dict[JobID, Job] = {}
        self._backend = _get_backend_impl(backend)

    def _on_log_message_cb(self, job: Job, message: str) -> None:
        msg = (datetime.now(timezone.utc), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, log messages to the console
        logger.info(f"Job {job.id}: {message}")
        job.append_log(msg)
        self._backend.on_log_message_cb(job, msg)

    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        job.status = status
        self._backend.on_status_change_cb(job, status)

    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        job.register_artifact(artifact)
        self._backend.on_artifact_collected_cb(job, artifact)

    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler | None, run_params: dict[str, Any]) -> JobID:
        job = Job(type_, job_id, handler)
        if job.id in self._jobs:
            raise ValueError(f"Job {job.id} already exists")

        # Set the world size if provided
        if "world_size" in run_params:
            job.world_size = run_params["world_size"]

        # Store all run parameters in the job's run_args dictionary
        job.run_args = run_params

        self._jobs[job.id] = job
        job.status = JobStatus.scheduled
        self._backend.schedule(
            job,
            functools.partial(self._on_log_message_cb, job),
            functools.partial(self._on_status_change_cb, job),
            functools.partial(self._on_artifact_collected_cb, job),
        )

        return job.id

    def cancel(self, job_id: JobID) -> None:
        self.get_job(job_id).cancel()

    def get_job(self, job_id: JobID) -> Job:
        try:
            return self._jobs[job_id]
        except KeyError as e:
            raise ValueError(f"Job {job_id} not found") from e

    def get_jobs(self, type_: JobType | None = None) -> list[Job]:
        jobs = list(self._jobs.values())
        if type_:
            jobs = [job for job in jobs if job._type == type_]
        return jobs

    async def shutdown(self):
        # TODO: also cancel jobs once implemented
        await self._backend.shutdown()


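# End-to-end usage sketch (illustrative; the job type, ID, and run_params are
# hypothetical). With world_size > 1 the distributed backend ignores the
# handler and launches finetune_handler.py via torchrun instead:
#
#     scheduler = Scheduler(backend="distributed")
#     scheduler.schedule(
#         type_="finetune",
#         job_id="job-1",
#         handler=None,
#         run_params={"world_size": 2, "model": "llama-3"},
#     )
#     job = scheduler.get_job("job-1")
#     print(job.status, job.logs)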