Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-30 18:00:04 +00:00
feat: handle graceful shutdown
Currently this implementation hangs because `trainer.train()` blocks. Rewrite the implementation to kick off the model download, device instantiation, dataset processing, and training in a monitored subprocess. All of these steps need to run in the same subprocess, or else different devices are used, which causes torch errors.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
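For context, the sketch below illustrates the monitored-subprocess pattern the message describes; it is not the code from this commit. The names `run_training` and `train_with_graceful_shutdown`, the polling interval, and the placeholder training body are all hypothetical.

```python
# Illustrative sketch only, not the llama-stack implementation.
import multiprocessing as mp
import signal
import time


def run_training(config: dict) -> None:
    # Hypothetical stand-in for the real work: in the provider, model download,
    # device instantiation, dataset processing, and trainer.train() all happen
    # here so that every step sees the same torch device.
    time.sleep(5)  # placeholder for the blocking trainer.train() call


def train_with_graceful_shutdown(config: dict, poll_interval: float = 1.0) -> int:
    """Run training in a child process and stay responsive to shutdown signals."""
    proc = mp.Process(target=run_training, args=(config,), daemon=True)
    proc.start()

    stop_requested = False

    def _request_stop(signum, frame):
        nonlocal stop_requested
        stop_requested = True

    # The parent never blocks inside trainer.train(), so SIGTERM/SIGINT can be
    # turned into a clean teardown of the child instead of a hang.
    signal.signal(signal.SIGTERM, _request_stop)
    signal.signal(signal.SIGINT, _request_stop)

    while proc.is_alive():
        if stop_requested:
            proc.terminate()          # ask the child to exit
            proc.join(timeout=10)
            if proc.is_alive():
                proc.kill()           # last resort if terminate() is ignored
        proc.join(timeout=poll_interval)

    return proc.exitcode or 0


if __name__ == "__main__":
    code = train_with_graceful_shutdown({"model": "example"})
    print(f"training subprocess exited with code {code}")
```

The key point is that every device-touching step happens inside the single child process, while the parent only polls the child and forwards shutdown requests.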
parent ff246d890a
commit 46c5b14a22

4 changed files with 387 additions and 312 deletions
```diff
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import gc
 import logging
 import os
 import time
```
```diff
@@ -47,6 +46,7 @@ from llama_stack.apis.post_training import (
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
     TorchtuneCheckpointer,
```
```diff
@@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice:
             checkpoints.append(checkpoint)
 
         # clean up the memory after training finishes
-        if self._device.type != "cpu":
-            self._model.to("cpu")
-            torch.cuda.empty_cache()
-        del self._model
-        gc.collect()
+        evacuate_model_from_device(self._model, self._device.type)
 
         return (memory_stats, checkpoints)
```
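The last hunk replaces the inline cleanup with the `evacuate_model_from_device` helper imported above. Based only on the code it replaces, the helper plausibly looks something like the sketch below; the real signature and behavior live in `llama_stack.providers.inline.post_training.common.utils` and may differ.

```python
# Sketch inferred from the deleted inline cleanup above, not the actual helper.
import gc

import torch


def evacuate_model_from_device(model: torch.nn.Module, device_type: str) -> None:
    """Free accelerator memory held by a model once training has finished."""
    if device_type != "cpu":
        model.to("cpu")           # move the weights off the accelerator
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
    del model                     # drop this frame's reference; callers should drop theirs too
    gc.collect()                  # encourage prompt collection of the freed tensors
```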