TMP test launching distributed training from inline provider

Signed-off-by: James Kunstle <jkunstle@redhat.com>
This commit is contained in:
James Kunstle 2025-03-20 18:56:52 -07:00
parent 68000499f7
commit 97d54778a3
2 changed files with 26 additions and 3 deletions

View file

@ -1,4 +1,5 @@
import asyncio
import pathlib
import tempfile
import typing
from asyncio import subprocess
@ -288,9 +289,19 @@ class FullPrecisionFineTuning:
set_subproc_ref_callback (Callable[[subprocess.Process], None]): Stores a reference to this job's subprocess in the 'Impl' class
"""
training_subproc = await asyncio.create_subprocess_shell(
'echo "yay Im running in a subprocess: $$"; sleep 5; echo "exiting subprocess $$"'
)
# assumes that SPMD training file is next to current file
train_file = pathlib.Path(__file__).resolve() / "train.py"
NGPU = 2
command = f"""
torchrun \
--nproc_per_node {NGPU} \
--rdzv_backend gloo \
--rdzv_endpoint="localhost:0" \
{train_file} \
"""
training_subproc = await asyncio.create_subprocess_shell(cmd=command)
set_subproc_ref_callback(training_subproc)
await training_subproc.wait()
set_status_callback(JobStatus.completed)

View file

@ -0,0 +1,12 @@
import os
import time
# Placeholder training script: each worker reports its rank information and
# then emits one progress line per step.
#
# Rank metadata is read from the environment; "UNKNOWN" is substituted when a
# variable is unset (e.g. when the script is run directly).
# NOTE(review): these are presumably populated by the torchrun launcher -- confirm.
CURRENT_LOCAL_RANK = os.getenv("LOCAL_RANK", "UNKNOWN")
CURRENT_RANK = os.getenv("RANK", "UNKNOWN")
CURRENT_WS = os.getenv("WORLD_SIZE", "UNKNOWN")


def format_progress(local_rank: str, step: int) -> str:
    """Return the progress line printed by the worker *local_rank* at *step*."""
    return f"LR:({local_rank}) | {step}"


def main(steps: int = 30, delay: float = 1.0) -> None:
    """Print a greeting, then one progress line per step.

    Args:
        steps: Number of progress lines to emit.
        delay: Seconds to sleep after each progress line.
    """
    # flush=True so each line is written immediately instead of sitting in a
    # buffer until the process exits when stdout is not a terminal.
    print(
        f"Hello from training script! LR:({CURRENT_LOCAL_RANK}) R:({CURRENT_RANK}) WS:({CURRENT_WS})",
        flush=True,
    )
    for step in range(steps):
        print(format_progress(CURRENT_LOCAL_RANK, step), flush=True)
        time.sleep(delay)


# Guarded so that importing this module does not start the sleep loop.
if __name__ == "__main__":
    main()