Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 18:50:44 +00:00
TMP test launching distributed training from inline provider
Signed-off-by: James Kunstle <jkunstle@redhat.com>
parent 68000499f7
commit 97d54778a3
2 changed files with 26 additions and 3 deletions
@@ -1,4 +1,5 @@
 import asyncio
+import pathlib
 import tempfile
 import typing
 from asyncio import subprocess
@@ -288,9 +289,19 @@ class FullPrecisionFineTuning:
             set_subproc_ref_callback (Callable[[subprocess.Process], None]): sets the subprocess reference for this job on the 'Impl' class
         """

-        training_subproc = await asyncio.create_subprocess_shell(
-            'echo "yay Im running in a subprocess: $$"; sleep 5; echo "exiting subprocess $$"'
-        )
+        # assumes that the SPMD training file sits next to the current file
+        train_file = pathlib.Path(__file__).resolve().parent / "train.py"
+        NGPU = 2
+
+        command = f"""
+        torchrun \
+            --nproc_per_node {NGPU} \
+            --rdzv_backend c10d \
+            --rdzv_endpoint="localhost:0" \
+            {train_file} \
+        """
+
+        training_subproc = await asyncio.create_subprocess_shell(cmd=command)
         set_subproc_ref_callback(training_subproc)
         await training_subproc.wait()
         set_status_callback(JobStatus.completed)
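For reference, the launch pattern in this hunk can be exercised on its own. Below is a minimal, self-contained sketch of the same idea (illustrative, not part of the commit; the `launch_training` name and `__main__` wrapper are assumptions). It uses the c10d rendezvous backend, torchrun's dynamic-rendezvous option ("gloo" and "nccl" are process-group backends, not rendezvous backends), and port 0 on the rendezvous endpoint so the OS picks a free port:

    # Standalone sketch of the launch above. Assumes torchrun is on PATH and
    # a train.py script sits next to this file.
    import asyncio
    import pathlib

    NGPU = 2

    async def launch_training() -> int:
        # .parent is required: __file__ resolves to this file, not its directory
        train_file = pathlib.Path(__file__).resolve().parent / "train.py"
        command = (
            f"torchrun "
            f"--nproc_per_node {NGPU} "
            f"--rdzv_backend c10d "
            f'--rdzv_endpoint "localhost:0" '  # port 0: OS picks a free port
            f"{train_file}"
        )
        proc = await asyncio.create_subprocess_shell(command)
        return await proc.wait()  # exit code of torchrun itself

    if __name__ == "__main__":
        raise SystemExit(asyncio.run(launch_training()))

Handing the `Process` handle to `set_subproc_ref_callback` is what keeps the job controllable from the provider side, presumably so the 'Impl' class can later monitor the run or terminate torchrun to cancel it.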
(new file: the train.py training script referenced above)
@@ -0,0 +1,12 @@
+import os
+import time
+
+CURRENT_LOCAL_RANK = os.getenv("LOCAL_RANK", "UNKNOWN")
+CURRENT_RANK = os.getenv("RANK", "UNKNOWN")
+CURRENT_WS = os.getenv("WORLD_SIZE", "UNKNOWN")
+
+print(f"Hello from training script! LR:({CURRENT_LOCAL_RANK}) R:({CURRENT_RANK}) WS:({CURRENT_WS})")
+
+for i in range(30):
+    print(f"LR:({CURRENT_LOCAL_RANK}) | {i}")
+    time.sleep(1)
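torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE (along with MASTER_ADDR and MASTER_PORT) into each worker's environment, which is exactly what this stub script reads back. A real SPMD entrypoint would typically go one step further and initialize a process group from those variables; a hedged sketch, assuming torch is installed (the `main` function is illustrative):

    # Sketch of a fuller train.py entrypoint. init_process_group() defaults to
    # the env:// init method, which reads the variables torchrun exports.
    import os

    import torch.distributed as dist

    def main() -> None:
        dist.init_process_group(backend="gloo")  # use "nccl" for GPU workers
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_rank = int(os.environ["LOCAL_RANK"])
        print(f"worker ready: rank {rank}/{world_size}, local rank {local_rank}")
        dist.barrier()  # simple SPMD synchronization point
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()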