Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 22:42:25 +00:00)
fix the changes requested in review
parent 41a45580e0
commit 518bf2fc34
6 changed files with 42 additions and 67 deletions
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import logging
-import multiprocessing
 import os
 import signal
 import sys
@@ -34,7 +33,7 @@ def setup_environment():
     os.environ["MKL_NUM_THREADS"] = "1"
 
 
-def get_gb(to_convert: int) -> str:
+def bytes_to_gb(to_convert: int) -> str:
     """Converts memory stats to GB and formats to 2 decimal places.
     Args:
         to_convert: Memory value in bytes
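The body of the renamed helper falls outside this hunk. A minimal sketch of what bytes_to_gb plausibly does, assuming a binary gigabyte (1024**3 bytes) and going only by the docstring shown above:

def bytes_to_gb(to_convert: int) -> str:
    """Converts memory stats to GB and formats to 2 decimal places."""
    # Assumption: GB here means 1024**3 bytes; the diff does not show the body.
    return f"{to_convert / (1024**3):.2f}"

The rename itself is the substance of the change: get_gb said nothing about the input unit, while bytes_to_gb makes the bytes-in, GB-out contract explicit.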
@@ -48,31 +47,31 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
     """Get memory statistics for the given device."""
     stats = {
         "system_memory": {
-            "total": get_gb(psutil.virtual_memory().total),
-            "available": get_gb(psutil.virtual_memory().available),
-            "used": get_gb(psutil.virtual_memory().used),
+            "total": bytes_to_gb(psutil.virtual_memory().total),
+            "available": bytes_to_gb(psutil.virtual_memory().available),
+            "used": bytes_to_gb(psutil.virtual_memory().used),
             "percent": psutil.virtual_memory().percent,
         }
     }
 
     if device.type == "cuda":
         stats["device_memory"] = {
-            "allocated": get_gb(torch.cuda.memory_allocated(device)),
-            "reserved": get_gb(torch.cuda.memory_reserved(device)),
-            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+            "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
+            "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
         }
     elif device.type == "mps":
         # MPS doesn't provide direct memory stats, but we can track system memory
         stats["device_memory"] = {
             "note": "MPS memory stats not directly available",
-            "system_memory_used": get_gb(psutil.virtual_memory().used),
+            "system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
         }
     elif device.type == "cpu":
         # For CPU, we track process memory usage
         process = psutil.Process()
         stats["device_memory"] = {
-            "process_rss": get_gb(process.memory_info().rss),
-            "process_vms": get_gb(process.memory_info().vms),
+            "process_rss": bytes_to_gb(process.memory_info().rss),
+            "process_vms": bytes_to_gb(process.memory_info().vms),
             "process_percent": process.memory_percent(),
         }
 
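For orientation, a hedged usage sketch of get_memory_stats after the rename; callers are unaffected, since only the internal helper name changed. The device choice and printing below are illustrative, not from the diff:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
stats = get_memory_stats(device)
# "system_memory" is always present; "device_memory" depends on device.type.
print(f"system used: {stats['system_memory']['used']} GB")
print(stats.get("device_memory", {}))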
@@ -115,7 +114,7 @@ def setup_torch_device(device_str: str) -> torch.device:
     return device
 
 
-async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
         all_rows = await datasetio_api.iterrows(
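The hunk is truncated inside the iterrows call, so only the signature rename is certain here. As a sketch of the likely shape: DatasetIO.iterrows in llama-stack is a paginated API, so load_rows_from_dataset presumably loops until the cursor is exhausted. The response field names (data, next_start_index) and the page size are assumptions, not taken from this diff:

async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
    """Load dataset from llama stack dataset provider (sketch, not the committed body)."""
    all_rows: list[dict[str, Any]] = []
    start_index = 0
    while True:
        # Assumed pagination contract: each call returns a page of rows
        # plus a cursor for the next page (None when exhausted).
        response = await datasetio_api.iterrows(dataset_id=dataset_id, start_index=start_index, limit=100)
        all_rows.extend(response.data)
        if response.next_start_index is None:
            break
        start_index = response.next_start_index
    return all_rows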
@@ -268,12 +267,3 @@ def create_checkpoints(
         checkpoints.append(checkpoint)
 
     return checkpoints
-
-
-def setup_multiprocessing_for_device(device: torch.device):
-    """Setup multiprocessing start method based on device type.
-    Args:
-        device: The device being used for training
-    """
-    if device.type in ["cuda", "mps"]:
-        multiprocessing.set_start_method("spawn", force=True)
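The deleted helper wrapped a single call. If the spawn start method is still required for CUDA/MPS training after this commit, the logic presumably moved to one of the other five changed files, which are not shown on this page; that placement is an assumption. The equivalent inline form, with the reason spelled out:

import multiprocessing

import torch

device = torch.device("cuda")  # illustrative; "mps" behaves the same way
if device.type in ["cuda", "mps"]:
    # fork()ed children inherit a CUDA/MPS context they cannot safely reuse,
    # so worker processes must be spawned fresh instead.
    multiprocessing.set_start_method("spawn", force=True)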