feat: handle graceful shutdown

Currently this implementation hangs because `trainer.train()` blocks.

Rewrite the implementation to kick off the model download, device instantiation, dataset processing, and training in a monitored subprocess.

All of these steps need to run in that subprocess; otherwise different devices end up being used, which causes torch errors.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
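
For illustration, here is a minimal sketch of the monitored-subprocess pattern described above, assuming a `multiprocessing`-based approach. `run_training`, `launch_and_monitor`, and their arguments are illustrative placeholders, not the provider's actual entry points:

```python
import multiprocessing


def run_training(job_uuid: str, config: dict) -> None:
    # Placeholder: model download, device instantiation, dataset processing,
    # and trainer.train() would all run here, inside the child process, so
    # the parent process never touches an accelerator device.
    ...


def launch_and_monitor(job_uuid: str, config: dict, poll_interval: float = 1.0) -> int | None:
    # "spawn" gives the child a clean interpreter with no inherited device state.
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=run_training, args=(job_uuid, config))
    proc.start()
    try:
        while proc.is_alive():
            # Short joins keep the parent responsive instead of blocking
            # indefinitely the way a direct trainer.train() call would.
            proc.join(timeout=poll_interval)
    except KeyboardInterrupt:
        # Graceful shutdown: stop the child rather than hanging on it.
        proc.terminate()
        proc.join()
    return proc.exitcode


if __name__ == "__main__":
    launch_and_monitor("job-1234", {"model": "example-model"})
```

Using the "spawn" start method matters here: the child gets a fresh interpreter with no inherited CUDA state, which is what keeps all device work confined to a single process.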
Charlie Doern, 2025-05-14 15:43:41 -04:00
commit 46c5b14a22 (parent ff246d890a)
4 changed files with 387 additions and 312 deletions

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import gc
 import logging
 import os
 import time
@@ -47,6 +46,7 @@ from llama_stack.apis.post_training import (
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
     TorchtuneCheckpointer,
@@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice:
         checkpoints.append(checkpoint)
         # clean up the memory after training finishes
-        if self._device.type != "cpu":
-            self._model.to("cpu")
-            torch.cuda.empty_cache()
-        del self._model
-        gc.collect()
+        evacuate_model_from_device(self._model, self._device.type)
         return (memory_stats, checkpoints)
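
For reference, the new `evacuate_model_from_device` helper consolidates the cleanup steps removed above. The following is a plausible sketch reconstructed from those removed lines; the actual helper in `llama_stack.providers.inline.post_training.common.utils` may differ in detail:

```python
import gc

import torch


def evacuate_model_from_device(model, device: str) -> None:
    """Free accelerator memory once training finishes (sketch, not the shipped code).

    Mirrors the removed call-site logic: move the model to CPU, drop the
    reference, collect garbage, and release cached CUDA blocks.
    """
    if device != "cpu":
        model.to("cpu")  # pull weights off the accelerator first
    del model  # drop this function's reference; callers should not reuse the model
    gc.collect()  # reclaim the Python-side objects
    if device == "cuda":
        torch.cuda.empty_cache()  # return cached blocks to the CUDA allocator
```

Centralizing the cleanup in one helper keeps the recipe code free of device-specific branches and makes the same teardown reusable by other post-training providers.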