llama-stack-mirror/llama_stack/providers/inline/post_training/common/utils.py
Charlie Doern 46c5b14a22 feat: handle graceful shutdown
Currently this implementation hangs because `trainer.train()` blocks the calling thread.

Rewrite the implementation to kick off the model download, device instantiation, dataset processing, and training in a monitored subprocess (see the sketch below).

All of these steps need to run in the same subprocess; otherwise they end up on different devices, which causes torch errors.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-05-16 16:41:24 -04:00
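
A minimal sketch of the monitored-subprocess pattern the commit describes, assuming an asyncio-based provider; `training_worker`, `run_training`, and the result-queue protocol are hypothetical names for illustration, not code from this repo:

import asyncio
import multiprocessing


def training_worker(result_queue: multiprocessing.Queue) -> None:
    # All device-touching steps (model download, device instantiation,
    # dataset processing, trainer.train()) happen inside this process.
    try:
        # ... download model, set up device, process dataset, train ...
        result_queue.put({"status": "completed"})
    except Exception as e:  # report failures back to the parent
        result_queue.put({"status": "failed", "error": str(e)})


async def run_training() -> dict:
    # "spawn" gives the child a clean CUDA/MPS context rather than
    # inheriting the parent's, avoiding mixed-device torch errors.
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    proc = ctx.Process(target=training_worker, args=(result_queue,))
    proc.start()
    try:
        # Poll the child so the event loop is never blocked by train().
        while proc.is_alive():
            await asyncio.sleep(1)
        return result_queue.get(timeout=5)
    except asyncio.CancelledError:
        # Graceful shutdown: terminate the child instead of hanging.
        proc.terminate()
        proc.join()
        raise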


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc


def evacuate_model_from_device(model, device: str):
    """Safely clear a model from memory and free device resources.

    This function handles the proper cleanup of a model by:
    1. Moving the model to CPU if it's on a non-CPU device
    2. Deleting the model object to free memory
    3. Running garbage collection
    4. Clearing CUDA cache if the model was on a CUDA device

    Args:
        model: The PyTorch model to clear
        device: The device type the model is currently on ('cuda', 'mps', 'cpu')

    Note:
        - For CUDA devices, this will clear the CUDA cache after moving the model to CPU
        - For MPS devices, only moves the model to CPU (no cache clearing available)
        - For CPU devices, only deletes the model object and runs garbage collection
    """
    if device != "cpu":
        model.to("cpu")

    del model
    gc.collect()

    if device == "cuda":
        # Import torch lazily so it is only required when this function is called.
        import torch

        torch.cuda.empty_cache()
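
For example, a caller that trained a model on `device` might clean up like this (a usage sketch; the Linear layer is a stand-in for a real fine-tuned model):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)  # stand-in for a trained model

evacuate_model_from_device(model, device)
del model  # also drop the caller's reference; the helper can only delete its own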