From 7ca44183440a69b8d5dcfb3571484d061bee4f76 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 23 Jul 2025 17:38:58 +0000
Subject: [PATCH] back to working

---
 .../recipes/finetune_single_device.py         | 99 ++++++++++++++++++-
 .../recipes/finetune_single_device_dpo.py     | 98 +++++++++++++++++-
 2 files changed, 192 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
index 762b80e0b..8f8faee9e 100644
--- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
@@ -47,11 +47,91 @@ from llama_stack.apis.post_training import (
 )
 
 from ..config import HuggingFacePostTrainingConfig
-from ..utils import get_gb, get_memory_stats, setup_torch_device, setup_data
 
 logger = logging.getLogger(__name__)
 
 
+def get_gb(to_convert: int) -> str:
+    """Converts memory stats to GB and formats to 2 decimal places.
+    Args:
+        to_convert: Memory value in bytes
+    Returns:
+        str: Memory value in GB formatted to 2 decimal places
+    """
+    return f"{(to_convert / (1024**3)):.2f}"
+
+
+def get_memory_stats(device: torch.device) -> dict[str, Any]:
+    """Get memory statistics for the given device."""
+    stats = {
+        "system_memory": {
+            "total": get_gb(psutil.virtual_memory().total),
+            "available": get_gb(psutil.virtual_memory().available),
+            "used": get_gb(psutil.virtual_memory().used),
+            "percent": psutil.virtual_memory().percent,
+        }
+    }
+
+    if device.type == "cuda":
+        stats["device_memory"] = {
+            "allocated": get_gb(torch.cuda.memory_allocated(device)),
+            "reserved": get_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+        }
+    elif device.type == "mps":
+        # MPS doesn't provide direct memory stats, but we can track system memory
+        stats["device_memory"] = {
+            "note": "MPS memory stats not directly available",
+            "system_memory_used": get_gb(psutil.virtual_memory().used),
+        }
+    elif device.type == "cpu":
+        # For CPU, we track process memory usage
+        process = psutil.Process()
+        stats["device_memory"] = {
+            "process_rss": get_gb(process.memory_info().rss),
+            "process_vms": get_gb(process.memory_info().vms),
+            "process_percent": process.memory_percent(),
+        }
+
+    return stats
+
+
+def setup_torch_device(device_str: str) -> torch.device:
+    """Initialize and validate a PyTorch device.
+    This function handles device initialization and validation for different device types:
+    - CUDA: Validates CUDA availability and handles device selection
+    - MPS: Validates MPS availability for Apple Silicon
+    - CPU: Basic validation
+    - HPU: Raises error as it's not supported
+    Args:
+        device_str: String specifying the device ('cuda', 'cpu', 'mps')
+    Returns:
+        torch.device: The initialized and validated device
+    Raises:
+        RuntimeError: If device initialization fails or device is not supported
+    """
+    try:
+        device = torch.device(device_str)
+    except RuntimeError as e:
+        raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
+
+    # Validate device capabilities
+    if device.type == "cuda":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
+            )
+        if device.index is None:
+            device = torch.device(device.type, torch.cuda.current_device())
+    elif device.type == "mps":
+        if not torch.backends.mps.is_available():
+            raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
+    elif device.type == "hpu":
+        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
+
+    return device
+
+
 class HFFinetuningSingleDevice:
     def __init__(
         self,
@@ -182,6 +262,19 @@ class HFFinetuningSingleDevice:
             remove_columns=ds.column_names,
         )
 
+    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
+        """Load dataset from llama stack dataset provider"""
+        try:
+            all_rows = await self.datasetio_api.iterrows(
+                dataset_id=dataset_id,
+                limit=-1,
+            )
+            if not isinstance(all_rows.data, list):
+                raise RuntimeError("Expected dataset data to be a list")
+            return all_rows.data
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+
     def _run_training_sync(
         self,
         model: str,
@@ -234,7 +327,7 @@ class HFFinetuningSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(config.data_config.dataset_id, self.datasetio_api)
+        rows = await self._setup_data(config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
         logger.info(f"Loaded {len(rows)} rows from dataset")
@@ -605,4 +698,4 @@ class HFFinetuningSingleDevice:
             return memory_stats, checkpoints if checkpoints else None
         finally:
             memory_stats["final"] = get_memory_stats(device)
-            gc.collect()
+            gc.collect()
\ No newline at end of file
diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
index a16d40314..39bf06d0e 100644
--- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
@@ -44,11 +44,105 @@ from llama_stack.apis.post_training import (
 )
 
 from ..config import HuggingFacePostTrainingConfig
-from ..utils import get_gb, get_memory_stats, setup_torch_device, setup_data
 
 logger = logging.getLogger(__name__)
 
 
+def get_gb(to_convert: int) -> str:
+    """Converts memory stats to GB and formats to 2 decimal places.
+    Args:
+        to_convert: Memory value in bytes
+    Returns:
+        str: Memory value in GB formatted to 2 decimal places
+    """
+    return f"{(to_convert / (1024**3)):.2f}"
+
+
+def get_memory_stats(device: torch.device) -> dict[str, Any]:
+    """Get memory statistics for the given device."""
+    stats = {
+        "system_memory": {
+            "total": get_gb(psutil.virtual_memory().total),
+            "available": get_gb(psutil.virtual_memory().available),
+            "used": get_gb(psutil.virtual_memory().used),
+            "percent": psutil.virtual_memory().percent,
+        }
+    }
+
+    if device.type == "cuda":
+        stats["device_memory"] = {
+            "allocated": get_gb(torch.cuda.memory_allocated(device)),
+            "reserved": get_gb(torch.cuda.memory_reserved(device)),
+            "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
+        }
+    elif device.type == "mps":
+        # MPS doesn't provide direct memory stats, but we can track system memory
+        stats["device_memory"] = {
+            "note": "MPS memory stats not directly available",
+            "system_memory_used": get_gb(psutil.virtual_memory().used),
+        }
+    elif device.type == "cpu":
+        # For CPU, we track process memory usage
+        process = psutil.Process()
+        stats["device_memory"] = {
+            "process_rss": get_gb(process.memory_info().rss),
+            "process_vms": get_gb(process.memory_info().vms),
+            "process_percent": process.memory_percent(),
+        }
+
+    return stats
+
+
+def setup_torch_device(device_str: str) -> torch.device:
+    """Initialize and validate a PyTorch device.
+    This function handles device initialization and validation for different device types:
+    - CUDA: Validates CUDA availability and handles device selection
+    - MPS: Validates MPS availability for Apple Silicon
+    - CPU: Basic validation
+    - HPU: Raises error as it's not supported
+    Args:
+        device_str: String specifying the device ('cuda', 'cpu', 'mps')
+    Returns:
+        torch.device: The initialized and validated device
+    Raises:
+        RuntimeError: If device initialization fails or device is not supported
+    """
+    try:
+        device = torch.device(device_str)
+    except RuntimeError as e:
+        raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
+
+    # Validate device capabilities
+    if device.type == "cuda":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
+            )
+        if device.index is None:
+            device = torch.device(device.type, torch.cuda.current_device())
+    elif device.type == "mps":
+        if not torch.backends.mps.is_available():
+            raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
+    elif device.type == "hpu":
+        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
+
+    return device
+
+
+async def setup_data(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
+    """Load dataset from llama stack dataset provider"""
+    try:
+        all_rows = await datasetio_api.iterrows(
+            dataset_id=dataset_id,
+            limit=-1,
+        )
+        if not isinstance(all_rows.data, list):
+            raise RuntimeError("Expected dataset data to be a list")
+        return all_rows.data
+    except Exception as e:
+        raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+
+
 class HFDPOAlignmentSingleDevice:
     def __init__(
         self,
@@ -216,7 +310,7 @@ class HFDPOAlignmentSingleDevice:
 
         # Load dataset
         logger.info(f"Loading dataset: {config.data_config.dataset_id}")
-        rows = await setup_data(config.data_config.dataset_id, self.datasetio_api)
+        rows = await setup_data(self.datasetio_api, config.data_config.dataset_id)
        if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: prompt, chosen, rejected")
         logger.info(f"Loaded {len(rows)} rows from dataset")
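
A minimal usage sketch of the newly inlined device helpers, assuming the
patched llama-stack tree and its dependencies are installed; the import path
mirrors the file touched by this patch:

    import json

    from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
        get_memory_stats,
        setup_torch_device,
    )

    # "cpu" always validates; "hpu" raises RuntimeError per the new guard, and a
    # bare "cuda" is resolved to an explicit index via torch.cuda.current_device().
    device = setup_torch_device("cpu")

    # Nested dicts of GB-formatted strings, e.g. {"system_memory": {...},
    # "device_memory": {"process_rss": ..., "process_vms": ..., ...}} on CPU.
    print(json.dumps(get_memory_stats(device), indent=2))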
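
Note that the DPO recipe's setup_data now takes (datasetio_api, dataset_id),
the reverse of the removed ..utils signature, which is why its call site is
swapped above. A smoke-test sketch under the same installed-tree assumption;
FakeDatasetIO and its one-row payload are invented stand-ins for the real
DatasetIO provider:

    import asyncio
    from types import SimpleNamespace

    from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
        setup_data,
    )

    class FakeDatasetIO:
        # Mimics the provider contract setup_data relies on: iterrows() returns
        # an object whose .data attribute is a list, else RuntimeError is raised.
        async def iterrows(self, dataset_id: str, limit: int = -1):
            return SimpleNamespace(data=[{"prompt": "p", "chosen": "c", "rejected": "r"}])

    rows = asyncio.run(setup_data(FakeDatasetIO(), "my-dpo-dataset"))
    assert len(rows) == 1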