This commit is contained in:
Botao Chen 2025-01-14 18:01:49 -08:00
parent 89e3f81520
commit a77b8126d0

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import logging
import os
import time
@ -580,6 +581,15 @@ class LoraFinetuningSingleDevice:
checkpoint.training_metrics = training_metrics
checkpoints.append(checkpoint)
# clean up the memory after training finishes
self._model.to("cpu")
del self._model
gc.collect()
torch.cuda.empty_cache()
print("Allocated:", torch.cuda.memory_allocated() / 1e6, "MB")
print("Reserved: ", torch.cuda.memory_reserved() / 1e6, "MB")
return (memory_stats, checkpoints)
async def validation(self) -> Tuple[float, float]: