mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 02:30:24 +00:00
feat: handle graceful shutdown
currently this impl hangs because of `trainer.train()` blocking. Re-write the implementation to kick off the model download, device instantiation, dataset processing, and training in a monitored subprocess. All of these steps need to be in a subprocess or else different devices are used which causes torch errors. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
ff246d890a
commit
46c5b14a22
4 changed files with 387 additions and 312 deletions
35
llama_stack/providers/inline/post_training/common/utils.py
Normal file
35
llama_stack/providers/inline/post_training/common/utils.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
|
||||
|
||||
def evacuate_model_from_device(model, device: str):
|
||||
"""Safely clear a model from memory and free device resources.
|
||||
This function handles the proper cleanup of a model by:
|
||||
1. Moving the model to CPU if it's on a non-CPU device
|
||||
2. Deleting the model object to free memory
|
||||
3. Running garbage collection
|
||||
4. Clearing CUDA cache if the model was on a CUDA device
|
||||
Args:
|
||||
model: The PyTorch model to clear
|
||||
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
|
||||
Note:
|
||||
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
|
||||
- For MPS devices, only moves the model to CPU (no cache clearing available)
|
||||
- For CPU devices, only deletes the model object and runs garbage collection
|
||||
"""
|
||||
if device != "cpu":
|
||||
model.to("cpu")
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
if device == "cuda":
|
||||
# we need to import such that this is only imported when the method is called
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
Loading…
Add table
Add a link
Reference in a new issue