Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 06:28:50 +00:00)

Commit 7ca4418344: "back to working"
Parent: 1c7be17113
2 changed files with 192 additions and 5 deletions
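Judging from the hunks below, the commit re-inlines the shared post-training helpers (get_gb, get_memory_stats, setup_torch_device, and the dataset loader) into both single-device HuggingFace recipes, drops the corresponding ..utils import, and repairs the two setup_data call sites: the SFT recipe now calls a private _setup_data method, while the DPO recipe keeps a module-level setup_data and passes its arguments in the declared order.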
@@ -47,11 +47,91 @@ from llama_stack.apis.post_training import (
 )

 from ..config import HuggingFacePostTrainingConfig
-from ..utils import get_gb, get_memory_stats, setup_torch_device, setup_data

 logger = logging.getLogger(__name__)


+def get_gb(to_convert: int) -> str:
+    """Converts memory stats to GB and formats to 2 decimal places.
+
+    Args:
+        to_convert: Memory value in bytes
+
+    Returns:
+        str: Memory value in GB formatted to 2 decimal places
+    """
+    return f"{(to_convert / (1024**3)):.2f}"
+
+
+def get_memory_stats(device: torch.device) -> dict[str, Any]:
+    """Get memory statistics for the given device."""
+    stats = {
+        "system_memory": {
+            "total": get_gb(psutil.virtual_memory().total),
+            "available": get_gb(psutil.virtual_memory().available),
+            "used": get_gb(psutil.virtual_memory().used),
+            "percent": psutil.virtual_memory().percent,
+        }
+    }
+
+    if device.type == "cuda":
+        stats["device_memory"] = {
+            "allocated": get_gb(torch.cuda.memory_allocated(device)),
+            "reserved": get_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+        }
+    elif device.type == "mps":
+        # MPS doesn't provide direct memory stats, but we can track system memory
+        stats["device_memory"] = {
+            "note": "MPS memory stats not directly available",
+            "system_memory_used": get_gb(psutil.virtual_memory().used),
+        }
+    elif device.type == "cpu":
+        # For CPU, we track process memory usage
+        process = psutil.Process()
+        stats["device_memory"] = {
+            "process_rss": get_gb(process.memory_info().rss),
+            "process_vms": get_gb(process.memory_info().vms),
+            "process_percent": process.memory_percent(),
+        }
+
+    return stats
+
+
+def setup_torch_device(device_str: str) -> torch.device:
+    """Initialize and validate a PyTorch device.
+
+    This function handles device initialization and validation for different device types:
+    - CUDA: Validates CUDA availability and handles device selection
+    - MPS: Validates MPS availability for Apple Silicon
+    - CPU: Basic validation
+    - HPU: Raises error as it's not supported
+
+    Args:
+        device_str: String specifying the device ('cuda', 'cpu', 'mps')
+
+    Returns:
+        torch.device: The initialized and validated device
+
+    Raises:
+        RuntimeError: If device initialization fails or device is not supported
+    """
+    try:
+        device = torch.device(device_str)
+    except RuntimeError as e:
+        raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
+
+    # Validate device capabilities
+    if device.type == "cuda":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
+            )
+        if device.index is None:
+            device = torch.device(device.type, torch.cuda.current_device())
+    elif device.type == "mps":
+        if not torch.backends.mps.is_available():
+            raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
+    elif device.type == "hpu":
+        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
+
+    return device
+
+
 class HFFinetuningSingleDevice:
     def __init__(
         self,
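As a quick orientation, here is a minimal sketch of how the re-inlined helpers compose; it is not part of the commit, and the device fallback and printed labels are illustrative:

import torch

# Prefer CUDA when present; setup_torch_device validates the string and
# raises RuntimeError for unavailable or unsupported ("hpu") devices.
device = setup_torch_device("cuda" if torch.cuda.is_available() else "cpu")

# get_memory_stats returns GB values pre-formatted as strings by get_gb.
stats = get_memory_stats(device)
print("system total (GB):", stats["system_memory"]["total"])
print("device memory:", stats.get("device_memory"))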
@@ -182,6 +262,19 @@ class HFFinetuningSingleDevice:
             remove_columns=ds.column_names,
         )

+    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
+        """Load dataset from llama stack dataset provider"""
+        try:
+            all_rows = await self.datasetio_api.iterrows(
+                dataset_id=dataset_id,
+                limit=-1,
+            )
+            if not isinstance(all_rows.data, list):
+                raise RuntimeError("Expected dataset data to be a list")
+            return all_rows.data
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+
     def _run_training_sync(
         self,
         model: str,
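The new _setup_data method pulls every row in a single iterrows call (limit=-1) and normalizes provider failures to RuntimeError. A hedged sketch of the call pattern, with the recipe instance and dataset id as placeholders:

import asyncio

async def demo(recipe: HFFinetuningSingleDevice) -> None:
    # _setup_data wraps self.datasetio_api.iterrows(dataset_id=..., limit=-1).
    rows = await recipe._setup_data("my-sft-dataset")  # hypothetical dataset id
    print(f"loaded {len(rows)} rows")

# asyncio.run(demo(recipe))  # given a recipe constructed with a datasetio_api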
@@ -234,7 +327,7 @@ class HFFinetuningSingleDevice:

         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(config.data_config.dataset_id, self.datasetio_api)
+        rows = await self._setup_data(config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
         logger.info(f"Loaded {len(rows)} rows from dataset")
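The ValueError above names the columns validate_dataset_format checks for. A hypothetical row that would pass, with invented values (the exact encoding of chat_completion_input is an assumption, not confirmed by this diff):

row = {
    "input_query": "What is the capital of France?",
    "expected_answer": "Paris",
    # Assumed to hold a serialized chat message list; the diff does not show the format.
    "chat_completion_input": '[{"role": "user", "content": "What is the capital of France?"}]',
}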
@@ -605,4 +698,4 @@ class HFFinetuningSingleDevice:
             return memory_stats, checkpoints if checkpoints else None
         finally:
             memory_stats["final"] = get_memory_stats(device)
             gc.collect()
@@ -44,11 +44,105 @@ from llama_stack.apis.post_training import (
 )

 from ..config import HuggingFacePostTrainingConfig
-from ..utils import get_gb, get_memory_stats, setup_torch_device, setup_data

 logger = logging.getLogger(__name__)


+def get_gb(to_convert: int) -> str:
+    """Converts memory stats to GB and formats to 2 decimal places.
+
+    Args:
+        to_convert: Memory value in bytes
+
+    Returns:
+        str: Memory value in GB formatted to 2 decimal places
+    """
+    return f"{(to_convert / (1024**3)):.2f}"
+
+
+def get_memory_stats(device: torch.device) -> dict[str, Any]:
+    """Get memory statistics for the given device."""
+    stats = {
+        "system_memory": {
+            "total": get_gb(psutil.virtual_memory().total),
+            "available": get_gb(psutil.virtual_memory().available),
+            "used": get_gb(psutil.virtual_memory().used),
+            "percent": psutil.virtual_memory().percent,
+        }
+    }
+
+    if device.type == "cuda":
+        stats["device_memory"] = {
+            "allocated": get_gb(torch.cuda.memory_allocated(device)),
+            "reserved": get_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+        }
+    elif device.type == "mps":
+        # MPS doesn't provide direct memory stats, but we can track system memory
+        stats["device_memory"] = {
+            "note": "MPS memory stats not directly available",
+            "system_memory_used": get_gb(psutil.virtual_memory().used),
+        }
+    elif device.type == "cpu":
+        # For CPU, we track process memory usage
+        process = psutil.Process()
+        stats["device_memory"] = {
+            "process_rss": get_gb(process.memory_info().rss),
+            "process_vms": get_gb(process.memory_info().vms),
+            "process_percent": process.memory_percent(),
+        }
+
+    return stats
+
+
+def setup_torch_device(device_str: str) -> torch.device:
+    """Initialize and validate a PyTorch device.
+
+    This function handles device initialization and validation for different device types:
+    - CUDA: Validates CUDA availability and handles device selection
+    - MPS: Validates MPS availability for Apple Silicon
+    - CPU: Basic validation
+    - HPU: Raises error as it's not supported
+
+    Args:
+        device_str: String specifying the device ('cuda', 'cpu', 'mps')
+
+    Returns:
+        torch.device: The initialized and validated device
+
+    Raises:
+        RuntimeError: If device initialization fails or device is not supported
+    """
+    try:
+        device = torch.device(device_str)
+    except RuntimeError as e:
+        raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
+
+    # Validate device capabilities
+    if device.type == "cuda":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
+            )
+        if device.index is None:
+            device = torch.device(device.type, torch.cuda.current_device())
+    elif device.type == "mps":
+        if not torch.backends.mps.is_available():
+            raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
+    elif device.type == "hpu":
+        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
+
+    return device
+
+
+async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+    """Load dataset from llama stack dataset provider"""
+    try:
+        all_rows = await datasetio_api.iterrows(
+            dataset_id=dataset_id,
+            limit=-1,
+        )
+        if not isinstance(all_rows.data, list):
+            raise RuntimeError("Expected dataset data to be a list")
+        return all_rows.data
+    except Exception as e:
+        raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+
+
 class HFDPOAlignmentSingleDevice:
     def __init__(
         self,
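Unlike the SFT file, this file keeps the loader as a module-level coroutine with the signature (datasetio_api, dataset_id), in that order. A minimal caller sketch (load_or_log and its logging are illustrative, not part of the commit):

async def load_or_log(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]] | None:
    # setup_data funnels every provider failure into RuntimeError, so callers
    # only need one except clause regardless of the datasetio backend.
    try:
        return await setup_data(datasetio_api, dataset_id)
    except RuntimeError as e:
        logger.error(f"Dataset load failed: {e}")
        return None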
@@ -216,7 +310,7 @@ class HFDPOAlignmentSingleDevice:

         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(config.data_config.dataset_id, self.datasetio_api)
+        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
         logger.info(f"Loaded {len(rows)} rows from dataset")
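This one-line swap looks like the "back to working" fix for this file: the old call passed config.data_config.dataset_id where setup_data expects the DatasetIO object, so iterrows would have been invoked on a string. A keyword-argument variant of the fixed call site that would have made the reversal impossible (illustrative only):

        rows = await setup_data(
            datasetio_api=self.datasetio_api,
            dataset_id=config.data_config.dataset_id,
        )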