fix the changes requested in review

This commit is contained in:
Nehanth 2025-07-29 17:22:51 +00:00
parent 41a45580e0
commit 518bf2fc34
6 changed files with 42 additions and 67 deletions


@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import logging
-import multiprocessing
 import os
 import signal
 import sys
@@ -34,7 +33,7 @@ def setup_environment():
     os.environ["MKL_NUM_THREADS"] = "1"
-def get_gb(to_convert: int) -> str:
+def bytes_to_gb(to_convert: int) -> str:
     """Converts memory stats to GB and formats to 2 decimal places.
     Args:
         to_convert: Memory value in bytes
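
The hunk above only shows the renamed signature and docstring; the body of bytes_to_gb is not part of this diff. A minimal sketch consistent with the docstring, assuming binary gigabytes (1024**3 bytes) and a two-decimal string, might look like this — the divisor and return format are assumptions, not the committed implementation:

def bytes_to_gb(to_convert: int) -> str:
    """Convert a byte count to gigabytes, formatted to 2 decimal places."""
    # Assumes binary gigabytes; the actual constant is not visible in the diff.
    return f"{to_convert / (1024**3):.2f}"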
@@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
     """Get memory statistics for the given device."""
     stats = {
         "system_memory": {
-            "total": get_gb(psutil.virtual_memory().total),
-            "available": get_gb(psutil.virtual_memory().available),
-            "used": get_gb(psutil.virtual_memory().used),
+            "total": bytes_to_gb(psutil.virtual_memory().total),
+            "available": bytes_to_gb(psutil.virtual_memory().available),
+            "used": bytes_to_gb(psutil.virtual_memory().used),
             "percent": psutil.virtual_memory().percent,
         }
     }
     if device.type == "cuda":
         stats["device_memory"] = {
-            "allocated": get_gb(torch.cuda.memory_allocated(device)),
-            "reserved": get_gb(torch.cuda.memory_reserved(device)),
-            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
+            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
         }
     elif device.type == "mps":
         # MPS doesn't provide direct memory stats, but we can track system memory
         stats["device_memory"] = {
             "note": "MPS memory stats not directly available",
-            "system_memory_used": get_gb(psutil.virtual_memory().used),
+            "system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
         }
     elif device.type == "cpu":
         # For CPU, we track process memory usage
         process = psutil.Process()
         stats["device_memory"] = {
-            "process_rss": get_gb(process.memory_info().rss),
-            "process_vms": get_gb(process.memory_info().vms),
+            "process_rss": bytes_to_gb(process.memory_info().rss),
+            "process_vms": bytes_to_gb(process.memory_info().vms),
             "process_percent": process.memory_percent(),
         }
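
For context, a small standalone snippet showing how the renamed conversions combine with psutil and torch at a hypothetical call site; the module path for the real helper is not shown in the diff, so this sketch redefines it locally:

import psutil
import torch


def bytes_to_gb(to_convert: int) -> str:
    # Sketch of the renamed helper; see the note above.
    return f"{to_convert / (1024**3):.2f}"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vm = psutil.virtual_memory()
print(f"system memory: {bytes_to_gb(vm.used)} / {bytes_to_gb(vm.total)} GB ({vm.percent}%)")
if device.type == "cuda":
    print(f"cuda allocated: {bytes_to_gb(torch.cuda.memory_allocated(device))} GB")
    print(f"cuda reserved:  {bytes_to_gb(torch.cuda.memory_reserved(device))} GB")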
@@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
     return device
-async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
         all_rows = await datasetio_api.iterrows(
@@ -268,12 +267,3 @@ def create_checkpoints(
         checkpoints.append(checkpoint)
     return checkpoints
-def setup_multiprocessing_for_device(device: torch.device):
-    """Setup multiprocessing start method based on device type.
-    Args:
-        device: The device being used for training
-    """
-    if device.type in ["cuda", "mps"]:
-        multiprocessing.set_start_method("spawn", force=True)
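
The removed setup_multiprocessing_for_device helper only wrapped a single stdlib call, which is why the multiprocessing import could be dropped as well. A caller that still needs spawn-based workers on CUDA or MPS can do the same thing inline — this is a hypothetical call site, assuming device is already a torch.device:

import multiprocessing

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type in ["cuda", "mps"]:
    # CUDA/MPS state generally cannot be shared across fork(), so force "spawn".
    multiprocessing.set_start_method("spawn", force=True)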