mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
feat: handle graceful shutdown
currently this impl hangs because of `trainer.train()` blocking. Re-write the implementation to kick off the model download, device instantiation, dataset processing, and training in a monitored subprocess. All of these steps need to be in a subprocess or else different devices are used which causes torch errors. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
ff246d890a
commit
46c5b14a22
4 changed files with 387 additions and 312 deletions
35
llama_stack/providers/inline/post_training/common/utils.py
Normal file
35
llama_stack/providers/inline/post_training/common/utils.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
|
||||||
|
|
||||||
|
def evacuate_model_from_device(model, device: str):
|
||||||
|
"""Safely clear a model from memory and free device resources.
|
||||||
|
This function handles the proper cleanup of a model by:
|
||||||
|
1. Moving the model to CPU if it's on a non-CPU device
|
||||||
|
2. Deleting the model object to free memory
|
||||||
|
3. Running garbage collection
|
||||||
|
4. Clearing CUDA cache if the model was on a CUDA device
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model to clear
|
||||||
|
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
|
||||||
|
Note:
|
||||||
|
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
|
||||||
|
- For MPS devices, only moves the model to CPU (no cache clearing available)
|
||||||
|
- For CPU devices, only deletes the model object and runs garbage collection
|
||||||
|
"""
|
||||||
|
if device != "cpu":
|
||||||
|
model.to("cpu")
|
||||||
|
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
# we need to import such that this is only imported when the method is called
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
|
@ -7,16 +7,26 @@
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
|
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||||
|
|
||||||
# Set tokenizer parallelism environment variable
|
# Set tokenizer parallelism environment variable
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
# Force PyTorch to use OpenBLAS instead of MKL
|
||||||
|
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
||||||
|
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
|
||||||
|
os.environ["MKL_NUM_THREADS"] = "1"
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft import LoraConfig
|
from peft import LoraConfig
|
||||||
|
@ -86,10 +96,46 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def setup_torch_device(device_str: str) -> torch.device:
|
||||||
|
"""Initialize and validate a PyTorch device.
|
||||||
|
This function handles device initialization and validation for different device types:
|
||||||
|
- CUDA: Validates CUDA availability and handles device selection
|
||||||
|
- MPS: Validates MPS availability for Apple Silicon
|
||||||
|
- CPU: Basic validation
|
||||||
|
- HPU: Raises error as it's not supported
|
||||||
|
Args:
|
||||||
|
device_str: String specifying the device ('cuda', 'cpu', 'mps')
|
||||||
|
Returns:
|
||||||
|
torch.device: The initialized and validated device
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If device initialization fails or device is not supported
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
device = torch.device(device_str)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
|
||||||
|
|
||||||
|
# Validate device capabilities
|
||||||
|
if device.type == "cuda":
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
|
||||||
|
)
|
||||||
|
if device.index is None:
|
||||||
|
device = torch.device(device.type, torch.cuda.current_device())
|
||||||
|
elif device.type == "mps":
|
||||||
|
if not torch.backends.mps.is_available():
|
||||||
|
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
|
||||||
|
elif device.type == "hpu":
|
||||||
|
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
class HFFinetuningSingleDevice:
|
class HFFinetuningSingleDevice:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
job_uuid,
|
job_uuid: str,
|
||||||
datasetio_api: DatasetIO,
|
datasetio_api: DatasetIO,
|
||||||
datasets_api: Datasets,
|
datasets_api: Datasets,
|
||||||
):
|
):
|
||||||
|
@ -216,58 +262,120 @@ class HFFinetuningSingleDevice:
|
||||||
remove_columns=ds.column_names,
|
remove_columns=ds.column_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
|
||||||
|
"""Load dataset from llama stack dataset provider"""
|
||||||
|
try:
|
||||||
|
all_rows = await self.datasetio_api.iterrows(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
limit=-1,
|
||||||
|
)
|
||||||
|
if not isinstance(all_rows.data, list):
|
||||||
|
raise RuntimeError("Expected dataset data to be a list")
|
||||||
|
return all_rows.data
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
|
||||||
|
|
||||||
|
def _run_training_sync(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
provider_config: dict[str, Any],
|
||||||
|
peft_config: LoraConfig | None,
|
||||||
|
config: dict[str, Any],
|
||||||
|
output_dir_path: Path | None,
|
||||||
|
) -> None:
|
||||||
|
"""Synchronous wrapper for running training process.
|
||||||
|
This method serves as a bridge between the multiprocessing Process and the async training function.
|
||||||
|
It creates a new event loop to run the async training process.
|
||||||
|
Args:
|
||||||
|
model: The model identifier to load
|
||||||
|
dataset_id: ID of the dataset to use for training
|
||||||
|
provider_config: Configuration specific to the HuggingFace provider
|
||||||
|
peft_config: Optional LoRA configuration
|
||||||
|
config: General training configuration
|
||||||
|
output_dir_path: Optional path to save the model
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
logger.info("Starting training process with async wrapper")
|
||||||
|
asyncio.run(
|
||||||
|
self._run_training(
|
||||||
|
model=model,
|
||||||
|
provider_config=provider_config,
|
||||||
|
peft_config=peft_config,
|
||||||
|
config=config,
|
||||||
|
output_dir_path=output_dir_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def load_dataset(
|
async def load_dataset(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
provider_config: HuggingFacePostTrainingConfig,
|
provider_config: HuggingFacePostTrainingConfig,
|
||||||
) -> tuple[Dataset, Dataset, AutoTokenizer]:
|
) -> tuple[Dataset, Dataset, AutoTokenizer]:
|
||||||
"""Load and preprocess the dataset for training.
|
"""Load and prepare the dataset for training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The model identifier to load
|
model: The model identifier to load
|
||||||
config: Training configuration containing dataset settings
|
config: Training configuration
|
||||||
provider_config: Provider-specific configuration
|
provider_config: Provider-specific configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple containing:
|
tuple: (train_dataset, eval_dataset, tokenizer)
|
||||||
- Training dataset
|
|
||||||
- Evaluation dataset
|
|
||||||
- Tokenizer
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If dataset is missing required fields
|
|
||||||
RuntimeError: If tokenizer initialization fails
|
|
||||||
"""
|
"""
|
||||||
assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
|
# Validate data config
|
||||||
rows = await self._setup_data(config.data_config.dataset_id)
|
if not config.data_config:
|
||||||
|
raise ValueError("DataConfig is required for training")
|
||||||
|
|
||||||
# Validate that the dataset has the required fields for training
|
# Load dataset
|
||||||
|
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||||
|
rows = await self._setup_data(config.data_config.dataset_id)
|
||||||
if not self.validate_dataset_format(rows):
|
if not self.validate_dataset_format(rows):
|
||||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||||
|
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||||
|
|
||||||
# Initialize tokenizer with model-specific config
|
# Initialize tokenizer
|
||||||
|
logger.info(f"Initializing tokenizer for model: {model}")
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||||
# Set up tokenizer defaults
|
|
||||||
|
# Set pad token to eos token if not present
|
||||||
|
# This is common for models that don't have a dedicated pad token
|
||||||
if not tokenizer.pad_token:
|
if not tokenizer.pad_token:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
# Set padding side to right for causal language modeling
|
||||||
|
# This ensures that padding tokens don't interfere with the model's ability
|
||||||
|
# to predict the next token in the sequence
|
||||||
tokenizer.padding_side = "right"
|
tokenizer.padding_side = "right"
|
||||||
|
|
||||||
|
# Set truncation side to right to keep the beginning of the sequence
|
||||||
|
# This is important for maintaining context and instruction format
|
||||||
tokenizer.truncation_side = "right"
|
tokenizer.truncation_side = "right"
|
||||||
|
|
||||||
|
# Set model max length to match provider config
|
||||||
|
# This ensures consistent sequence lengths across the training process
|
||||||
tokenizer.model_max_length = provider_config.max_seq_length
|
tokenizer.model_max_length = provider_config.max_seq_length
|
||||||
|
|
||||||
|
logger.info("Tokenizer initialized successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||||
|
|
||||||
# Create and preprocess dataset
|
# Create and preprocess dataset
|
||||||
|
logger.info("Creating and preprocessing dataset")
|
||||||
try:
|
try:
|
||||||
ds = self._create_dataset(rows, config, provider_config)
|
ds = self._create_dataset(rows, config, provider_config)
|
||||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||||
|
logger.info(f"Dataset created with {len(ds)} examples")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||||
|
|
||||||
# Split dataset into train and validation
|
# Split dataset
|
||||||
|
logger.info("Splitting dataset into train and validation sets")
|
||||||
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
|
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
|
||||||
return train_val_split["train"], train_val_split["test"], tokenizer
|
train_dataset = train_val_split["train"]
|
||||||
|
eval_dataset = train_val_split["test"]
|
||||||
|
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
|
||||||
|
|
||||||
|
return train_dataset, eval_dataset, tokenizer
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
self,
|
self,
|
||||||
|
@ -276,15 +384,12 @@ class HFFinetuningSingleDevice:
|
||||||
provider_config: HuggingFacePostTrainingConfig,
|
provider_config: HuggingFacePostTrainingConfig,
|
||||||
) -> AutoModelForCausalLM:
|
) -> AutoModelForCausalLM:
|
||||||
"""Load and initialize the model for training.
|
"""Load and initialize the model for training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The model identifier to load
|
model: The model identifier to load
|
||||||
device: The device to load the model onto
|
device: The device to load the model onto
|
||||||
provider_config: Provider-specific configuration
|
provider_config: Provider-specific configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The loaded and initialized model
|
The loaded and initialized model
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If model loading fails
|
RuntimeError: If model loading fails
|
||||||
"""
|
"""
|
||||||
|
@ -293,18 +398,201 @@ class HFFinetuningSingleDevice:
|
||||||
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
|
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
|
||||||
model_obj = AutoModelForCausalLM.from_pretrained(
|
model_obj = AutoModelForCausalLM.from_pretrained(
|
||||||
model,
|
model,
|
||||||
torch_dtype="auto",
|
torch_dtype="auto" if device.type != "cpu" else "float32",
|
||||||
quantization_config=None,
|
quantization_config=None,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
**provider_config.model_specific_config,
|
**provider_config.model_specific_config,
|
||||||
)
|
)
|
||||||
if model_obj.device != device:
|
# Always move model to specified device
|
||||||
model_obj = model_obj.to(device)
|
model_obj = model_obj.to(device)
|
||||||
logger.info(f"Model loaded and moved to device: {model_obj.device}")
|
logger.info(f"Model loaded and moved to device: {model_obj.device}")
|
||||||
return model_obj
|
return model_obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
||||||
|
|
||||||
|
def setup_training_args(
|
||||||
|
self,
|
||||||
|
config: TrainingConfig,
|
||||||
|
provider_config: HuggingFacePostTrainingConfig,
|
||||||
|
device: torch.device,
|
||||||
|
output_dir_path: Path | None,
|
||||||
|
steps_per_epoch: int,
|
||||||
|
) -> SFTConfig:
|
||||||
|
"""Setup training arguments.
|
||||||
|
Args:
|
||||||
|
config: Training configuration
|
||||||
|
provider_config: Provider-specific configuration
|
||||||
|
device: The device to train on
|
||||||
|
output_dir_path: Optional path to save the model
|
||||||
|
steps_per_epoch: Number of steps per epoch
|
||||||
|
Returns:
|
||||||
|
Configured SFTConfig object
|
||||||
|
"""
|
||||||
|
logger.info("Configuring training arguments")
|
||||||
|
lr = 2e-5
|
||||||
|
if config.optimizer_config:
|
||||||
|
lr = config.optimizer_config.lr
|
||||||
|
logger.info(f"Using custom learning rate: {lr}")
|
||||||
|
|
||||||
|
# Validate data config
|
||||||
|
if not config.data_config:
|
||||||
|
raise ValueError("DataConfig is required for training")
|
||||||
|
data_config = config.data_config
|
||||||
|
|
||||||
|
# Calculate steps
|
||||||
|
total_steps = steps_per_epoch * config.n_epochs
|
||||||
|
max_steps = min(config.max_steps_per_epoch, total_steps)
|
||||||
|
eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch
|
||||||
|
save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch
|
||||||
|
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
|
||||||
|
|
||||||
|
logger.info("Training configuration:")
|
||||||
|
logger.info(f"- Steps per epoch: {steps_per_epoch}")
|
||||||
|
logger.info(f"- Total steps: {total_steps}")
|
||||||
|
logger.info(f"- Max steps: {max_steps}")
|
||||||
|
logger.info(f"- Eval steps: {eval_steps}")
|
||||||
|
logger.info(f"- Save steps: {save_steps}")
|
||||||
|
logger.info(f"- Logging steps: {logging_steps}")
|
||||||
|
|
||||||
|
# Configure save strategy
|
||||||
|
save_strategy = "no"
|
||||||
|
if output_dir_path:
|
||||||
|
save_strategy = "steps"
|
||||||
|
logger.info(f"Will save checkpoints to {output_dir_path}")
|
||||||
|
|
||||||
|
return SFTConfig(
|
||||||
|
max_steps=max_steps,
|
||||||
|
output_dir=str(output_dir_path) if output_dir_path is not None else None,
|
||||||
|
num_train_epochs=config.n_epochs,
|
||||||
|
per_device_train_batch_size=data_config.batch_size,
|
||||||
|
fp16=device.type == "cuda",
|
||||||
|
bf16=False, # Causes CPU issues.
|
||||||
|
eval_strategy="steps",
|
||||||
|
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
||||||
|
save_strategy=save_strategy,
|
||||||
|
report_to="none",
|
||||||
|
max_seq_length=provider_config.max_seq_length,
|
||||||
|
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||||
|
gradient_checkpointing=provider_config.gradient_checkpointing,
|
||||||
|
learning_rate=lr,
|
||||||
|
warmup_ratio=provider_config.warmup_ratio,
|
||||||
|
weight_decay=provider_config.weight_decay,
|
||||||
|
remove_unused_columns=False,
|
||||||
|
dataloader_pin_memory=provider_config.dataloader_pin_memory,
|
||||||
|
dataloader_num_workers=provider_config.dataloader_num_workers,
|
||||||
|
dataset_text_field="text",
|
||||||
|
packing=False,
|
||||||
|
load_best_model_at_end=True if output_dir_path else False,
|
||||||
|
metric_for_best_model="eval_loss",
|
||||||
|
greater_is_better=False,
|
||||||
|
eval_steps=eval_steps,
|
||||||
|
save_steps=save_steps,
|
||||||
|
logging_steps=logging_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_model(
|
||||||
|
self,
|
||||||
|
model_obj: AutoModelForCausalLM,
|
||||||
|
trainer: SFTTrainer,
|
||||||
|
peft_config: LoraConfig | None,
|
||||||
|
output_dir_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Save the trained model.
|
||||||
|
Args:
|
||||||
|
model_obj: The model to save
|
||||||
|
trainer: The trainer instance
|
||||||
|
peft_config: Optional LoRA configuration
|
||||||
|
output_dir_path: Path to save the model
|
||||||
|
"""
|
||||||
|
logger.info("Saving final model")
|
||||||
|
model_obj.config.use_cache = True
|
||||||
|
|
||||||
|
if peft_config:
|
||||||
|
logger.info("Merging LoRA weights with base model")
|
||||||
|
model_obj = trainer.model.merge_and_unload()
|
||||||
|
else:
|
||||||
|
model_obj = trainer.model
|
||||||
|
|
||||||
|
save_path = output_dir_path / "merged_model"
|
||||||
|
logger.info(f"Saving model to {save_path}")
|
||||||
|
model_obj.save_pretrained(save_path)
|
||||||
|
|
||||||
|
async def _run_training(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
provider_config: dict[str, Any],
|
||||||
|
peft_config: LoraConfig | None,
|
||||||
|
config: dict[str, Any],
|
||||||
|
output_dir_path: Path | None,
|
||||||
|
) -> None:
|
||||||
|
"""Run the training process with signal handling."""
|
||||||
|
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
"""Handle termination signals gracefully."""
|
||||||
|
logger.info(f"Received signal {signum}, initiating graceful shutdown")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Convert config dicts back to objects
|
||||||
|
logger.info("Initializing configuration objects")
|
||||||
|
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
|
||||||
|
config_obj = TrainingConfig(**config)
|
||||||
|
|
||||||
|
# Initialize and validate device
|
||||||
|
device = setup_torch_device(provider_config_obj.device)
|
||||||
|
logger.info(f"Using device '{device}'")
|
||||||
|
|
||||||
|
# Load dataset and tokenizer
|
||||||
|
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
|
||||||
|
|
||||||
|
# Calculate steps per epoch
|
||||||
|
if not config_obj.data_config:
|
||||||
|
raise ValueError("DataConfig is required for training")
|
||||||
|
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
|
||||||
|
|
||||||
|
# Setup training arguments
|
||||||
|
training_args = self.setup_training_args(
|
||||||
|
config_obj,
|
||||||
|
provider_config_obj,
|
||||||
|
device,
|
||||||
|
output_dir_path,
|
||||||
|
steps_per_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model_obj = self.load_model(model, device, provider_config_obj)
|
||||||
|
|
||||||
|
# Initialize trainer
|
||||||
|
logger.info("Initializing SFTTrainer")
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model=model_obj,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
peft_config=peft_config,
|
||||||
|
args=training_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Train
|
||||||
|
logger.info("Starting training")
|
||||||
|
trainer.train()
|
||||||
|
logger.info("Training completed successfully")
|
||||||
|
|
||||||
|
# Save final model if output directory is provided
|
||||||
|
if output_dir_path:
|
||||||
|
self.save_model(model_obj, trainer, peft_config, output_dir_path)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up resources
|
||||||
|
logger.info("Cleaning up resources")
|
||||||
|
if hasattr(trainer, "model"):
|
||||||
|
evacuate_model_from_device(trainer.model, device.type)
|
||||||
|
del trainer
|
||||||
|
gc.collect()
|
||||||
|
logger.info("Cleanup completed")
|
||||||
|
|
||||||
async def train(
|
async def train(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -315,50 +603,21 @@ class HFFinetuningSingleDevice:
|
||||||
provider_config: HuggingFacePostTrainingConfig,
|
provider_config: HuggingFacePostTrainingConfig,
|
||||||
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
|
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
|
||||||
"""Train a model using HuggingFace's SFTTrainer"""
|
"""Train a model using HuggingFace's SFTTrainer"""
|
||||||
try:
|
# Initialize and validate device
|
||||||
device = torch.device(provider_config.device)
|
device = setup_torch_device(provider_config.device)
|
||||||
except RuntimeError as e:
|
|
||||||
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
|
|
||||||
|
|
||||||
# Detect device type and validate
|
|
||||||
if device.type == "cuda":
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise RuntimeError(
|
|
||||||
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
|
|
||||||
)
|
|
||||||
# map unqualified 'cuda' to current device
|
|
||||||
if device.index is None:
|
|
||||||
device = torch.device(device.type, torch.cuda.current_device())
|
|
||||||
elif device.type == "mps":
|
|
||||||
if not torch.backends.mps.is_available():
|
|
||||||
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
|
|
||||||
elif device.type == "hpu":
|
|
||||||
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
|
|
||||||
|
|
||||||
logger.info(f"Using device '{device}'")
|
logger.info(f"Using device '{device}'")
|
||||||
|
|
||||||
output_dir_path = None
|
output_dir_path = None
|
||||||
if output_dir:
|
if output_dir:
|
||||||
output_dir_path = Path(output_dir)
|
output_dir_path = Path(output_dir)
|
||||||
|
|
||||||
# Track memory stats throughout training
|
# Track memory stats
|
||||||
memory_stats = {
|
memory_stats = {
|
||||||
"initial": get_memory_stats(device),
|
"initial": get_memory_stats(device),
|
||||||
"after_model_load": None,
|
|
||||||
"after_training": None,
|
"after_training": None,
|
||||||
"final": None,
|
"final": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Validate data config
|
|
||||||
if not config.data_config:
|
|
||||||
raise ValueError("DataConfig is required for training")
|
|
||||||
|
|
||||||
# Load dataset and tokenizer
|
|
||||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config, provider_config)
|
|
||||||
|
|
||||||
# Load model with model-specific config
|
|
||||||
model_obj = self.load_model(model, device, provider_config)
|
|
||||||
memory_stats["after_model_load"] = get_memory_stats(device)
|
|
||||||
|
|
||||||
# Configure LoRA
|
# Configure LoRA
|
||||||
peft_config = None
|
peft_config = None
|
||||||
if lora_config:
|
if lora_config:
|
||||||
|
@ -371,93 +630,43 @@ class HFFinetuningSingleDevice:
|
||||||
target_modules=lora_config.lora_attn_modules,
|
target_modules=lora_config.lora_attn_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup training arguments
|
# Validate data config
|
||||||
lr = 2e-5
|
if not config.data_config:
|
||||||
if config.optimizer_config:
|
raise ValueError("DataConfig is required for training")
|
||||||
lr = config.optimizer_config.lr
|
|
||||||
|
|
||||||
# Calculate steps per epoch and appropriate intervals
|
# Train in a separate process
|
||||||
steps_per_epoch = len(train_dataset) // config.data_config.batch_size
|
logger.info("Starting training in separate process")
|
||||||
eval_steps = max(1, steps_per_epoch // 10) # Evaluate 10 times per epoch
|
|
||||||
save_steps = max(1, steps_per_epoch // 5) # Save 5 times per epoch
|
|
||||||
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
|
|
||||||
|
|
||||||
logger.info(f"Dataset size: {len(train_dataset)} examples")
|
|
||||||
logger.info(f"Batch size: {config.data_config.batch_size}")
|
|
||||||
logger.info(f"Steps per epoch: {steps_per_epoch}")
|
|
||||||
logger.info(f"Will evaluate every {eval_steps} steps")
|
|
||||||
logger.info(f"Will save every {save_steps} steps")
|
|
||||||
logger.info(f"Will log every {logging_steps} steps")
|
|
||||||
|
|
||||||
# save_strategy should be none if output dir is none
|
|
||||||
save_strategy = "no"
|
|
||||||
if output_dir_path:
|
|
||||||
save_strategy = "steps"
|
|
||||||
training_arguments = SFTConfig(
|
|
||||||
max_steps=config.max_steps_per_epoch,
|
|
||||||
output_dir=str(output_dir_path) if output_dir_path is not None else None,
|
|
||||||
num_train_epochs=config.n_epochs,
|
|
||||||
per_device_train_batch_size=config.data_config.batch_size,
|
|
||||||
fp16=device.type == "cuda",
|
|
||||||
bf16=device.type != "cuda",
|
|
||||||
# use_cpu should only be set if we are on a "True" CPU machine, not a MPS enabled Mac due to stability issues.
|
|
||||||
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
|
||||||
save_strategy=save_strategy,
|
|
||||||
save_steps=save_steps,
|
|
||||||
report_to="none",
|
|
||||||
max_seq_length=provider_config.max_seq_length,
|
|
||||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
|
||||||
gradient_checkpointing=provider_config.gradient_checkpointing,
|
|
||||||
learning_rate=lr,
|
|
||||||
warmup_ratio=provider_config.warmup_ratio,
|
|
||||||
weight_decay=provider_config.weight_decay,
|
|
||||||
logging_steps=logging_steps,
|
|
||||||
# Enable validation
|
|
||||||
eval_strategy="steps",
|
|
||||||
eval_steps=eval_steps,
|
|
||||||
save_total_limit=provider_config.save_total_limit,
|
|
||||||
remove_unused_columns=False,
|
|
||||||
dataloader_pin_memory=provider_config.dataloader_pin_memory,
|
|
||||||
dataloader_num_workers=provider_config.dataloader_num_workers,
|
|
||||||
dataset_text_field="text",
|
|
||||||
packing=False,
|
|
||||||
# Add evaluation metrics
|
|
||||||
# loading the best model can only happen if we have saved a model
|
|
||||||
load_best_model_at_end=True if output_dir_path else False,
|
|
||||||
metric_for_best_model="eval_loss",
|
|
||||||
greater_is_better=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize trainer with both train and eval datasets
|
|
||||||
trainer = SFTTrainer(
|
|
||||||
model=model_obj,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
peft_config=peft_config,
|
|
||||||
args=training_arguments,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train
|
|
||||||
logger.info("Starting training")
|
|
||||||
try:
|
try:
|
||||||
trainer.train()
|
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
|
||||||
|
if device.type in ["cuda", "mps"]:
|
||||||
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
process = multiprocessing.Process(
|
||||||
|
target=self._run_training_sync,
|
||||||
|
kwargs={
|
||||||
|
"model": model,
|
||||||
|
"provider_config": provider_config.model_dump(),
|
||||||
|
"peft_config": peft_config,
|
||||||
|
"config": config.model_dump(),
|
||||||
|
"output_dir_path": output_dir_path,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
process.start()
|
||||||
|
|
||||||
|
# Monitor the process
|
||||||
|
while process.is_alive():
|
||||||
|
process.join(timeout=1) # Check every second
|
||||||
|
if not process.is_alive():
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get the return code
|
||||||
|
if process.exitcode != 0:
|
||||||
|
raise RuntimeError(f"Training failed with exit code {process.exitcode}")
|
||||||
|
|
||||||
memory_stats["after_training"] = get_memory_stats(device)
|
memory_stats["after_training"] = get_memory_stats(device)
|
||||||
|
|
||||||
# Save final model
|
|
||||||
model_obj.config.use_cache = True
|
|
||||||
# if we have LoRA we need to do `merge_and_unload`
|
|
||||||
if lora_config:
|
|
||||||
model_obj = trainer.model.merge_and_unload()
|
|
||||||
else:
|
|
||||||
model_obj = trainer.model
|
|
||||||
|
|
||||||
checkpoint = None
|
|
||||||
checkpoints = None
|
checkpoints = None
|
||||||
# only save a final model if checkpoint dir is specified
|
|
||||||
# this is especially useful to test training rather than saving of checkpoints
|
|
||||||
if output_dir_path:
|
if output_dir_path:
|
||||||
model_obj.save_pretrained(output_dir_path / "merged_model")
|
|
||||||
|
|
||||||
# Create checkpoint
|
# Create checkpoint
|
||||||
checkpoint = Checkpoint(
|
checkpoint = Checkpoint(
|
||||||
identifier=f"{model}-sft-{config.n_epochs}",
|
identifier=f"{model}-sft-{config.n_epochs}",
|
||||||
|
@ -470,33 +679,5 @@ class HFFinetuningSingleDevice:
|
||||||
|
|
||||||
return memory_stats, checkpoints
|
return memory_stats, checkpoints
|
||||||
finally:
|
finally:
|
||||||
# Clean up resources
|
|
||||||
if hasattr(trainer, "model"):
|
|
||||||
if device.type != "cpu":
|
|
||||||
trainer.model.to("cpu")
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
del trainer.model
|
|
||||||
del trainer
|
|
||||||
gc.collect()
|
|
||||||
memory_stats["final"] = get_memory_stats(device)
|
memory_stats["final"] = get_memory_stats(device)
|
||||||
|
gc.collect()
|
||||||
async def _setup_data(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Load dataset from llama stack dataset provider"""
|
|
||||||
try:
|
|
||||||
|
|
||||||
async def fetch_rows(dataset_id: str):
|
|
||||||
return await self.datasetio_api.iterrows(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
limit=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
all_rows = await fetch_rows(dataset_id)
|
|
||||||
if not isinstance(all_rows.data, list):
|
|
||||||
raise RuntimeError("Expected dataset data to be a list")
|
|
||||||
return all_rows.data
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
|
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
@ -47,6 +46,7 @@ from llama_stack.apis.post_training import (
|
||||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
|
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||||
TorchtuneCheckpointer,
|
TorchtuneCheckpointer,
|
||||||
|
@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice:
|
||||||
checkpoints.append(checkpoint)
|
checkpoints.append(checkpoint)
|
||||||
|
|
||||||
# clean up the memory after training finishes
|
# clean up the memory after training finishes
|
||||||
if self._device.type != "cpu":
|
evacuate_model_from_device(self._model, self._device.type)
|
||||||
self._model.to("cpu")
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
del self._model
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return (memory_stats, checkpoints)
|
return (memory_stats, checkpoints)
|
||||||
|
|
||||||
|
|
137
requirements.txt
137
requirements.txt
|
@ -1,206 +1,69 @@
|
||||||
# This file was autogenerated by uv via the following command:
|
# This file was autogenerated by uv via the following command:
|
||||||
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
|
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
# via pydantic
|
|
||||||
anyio==4.8.0
|
anyio==4.8.0
|
||||||
# via
|
|
||||||
# httpx
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
attrs==25.1.0
|
attrs==25.1.0
|
||||||
# via
|
|
||||||
# jsonschema
|
|
||||||
# referencing
|
|
||||||
blobfile==3.0.0
|
blobfile==3.0.0
|
||||||
# via llama-stack
|
|
||||||
cachetools==5.5.2
|
cachetools==5.5.2
|
||||||
# via google-auth
|
|
||||||
certifi==2025.1.31
|
certifi==2025.1.31
|
||||||
# via
|
|
||||||
# httpcore
|
|
||||||
# httpx
|
|
||||||
# kubernetes
|
|
||||||
# requests
|
|
||||||
charset-normalizer==3.4.1
|
charset-normalizer==3.4.1
|
||||||
# via requests
|
|
||||||
click==8.1.8
|
click==8.1.8
|
||||||
# via llama-stack-client
|
|
||||||
colorama==0.4.6 ; sys_platform == 'win32'
|
colorama==0.4.6 ; sys_platform == 'win32'
|
||||||
# via
|
|
||||||
# click
|
|
||||||
# tqdm
|
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
# via
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
durationpy==0.9
|
durationpy==0.9
|
||||||
# via kubernetes
|
|
||||||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||||
# via anyio
|
|
||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
# via
|
|
||||||
# blobfile
|
|
||||||
# huggingface-hub
|
|
||||||
fire==0.7.0
|
fire==0.7.0
|
||||||
# via llama-stack
|
|
||||||
fsspec==2024.12.0
|
fsspec==2024.12.0
|
||||||
# via huggingface-hub
|
|
||||||
google-auth==2.38.0
|
google-auth==2.38.0
|
||||||
# via kubernetes
|
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
# via
|
|
||||||
# httpcore
|
|
||||||
# llama-stack
|
|
||||||
httpcore==1.0.9
|
httpcore==1.0.9
|
||||||
# via httpx
|
|
||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
# via
|
|
||||||
# llama-stack
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
huggingface-hub==0.29.0
|
huggingface-hub==0.29.0
|
||||||
# via llama-stack
|
|
||||||
idna==3.10
|
idna==3.10
|
||||||
# via
|
|
||||||
# anyio
|
|
||||||
# httpx
|
|
||||||
# requests
|
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
# via llama-stack
|
|
||||||
jiter==0.8.2
|
jiter==0.8.2
|
||||||
# via openai
|
|
||||||
jsonschema==4.23.0
|
jsonschema==4.23.0
|
||||||
# via llama-stack
|
|
||||||
jsonschema-specifications==2024.10.1
|
jsonschema-specifications==2024.10.1
|
||||||
# via jsonschema
|
|
||||||
kubernetes==32.0.1
|
kubernetes==32.0.1
|
||||||
# via llama-stack
|
|
||||||
llama-stack-client==0.2.7
|
llama-stack-client==0.2.7
|
||||||
# via llama-stack
|
|
||||||
lxml==5.3.1
|
lxml==5.3.1
|
||||||
# via blobfile
|
|
||||||
markdown-it-py==3.0.0
|
markdown-it-py==3.0.0
|
||||||
# via rich
|
|
||||||
markupsafe==3.0.2
|
markupsafe==3.0.2
|
||||||
# via jinja2
|
|
||||||
mdurl==0.1.2
|
mdurl==0.1.2
|
||||||
# via markdown-it-py
|
|
||||||
numpy==2.2.3
|
numpy==2.2.3
|
||||||
# via pandas
|
|
||||||
oauthlib==3.2.2
|
oauthlib==3.2.2
|
||||||
# via
|
|
||||||
# kubernetes
|
|
||||||
# requests-oauthlib
|
|
||||||
openai==1.71.0
|
openai==1.71.0
|
||||||
# via llama-stack
|
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
# via huggingface-hub
|
|
||||||
pandas==2.2.3
|
pandas==2.2.3
|
||||||
# via llama-stack-client
|
|
||||||
pillow==11.1.0
|
pillow==11.1.0
|
||||||
# via llama-stack
|
|
||||||
prompt-toolkit==3.0.50
|
prompt-toolkit==3.0.50
|
||||||
# via
|
|
||||||
# llama-stack
|
|
||||||
# llama-stack-client
|
|
||||||
pyaml==25.1.0
|
pyaml==25.1.0
|
||||||
# via llama-stack-client
|
|
||||||
pyasn1==0.6.1
|
pyasn1==0.6.1
|
||||||
# via
|
|
||||||
# pyasn1-modules
|
|
||||||
# rsa
|
|
||||||
pyasn1-modules==0.4.2
|
pyasn1-modules==0.4.2
|
||||||
# via google-auth
|
|
||||||
pycryptodomex==3.21.0
|
pycryptodomex==3.21.0
|
||||||
# via blobfile
|
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
# via
|
|
||||||
# llama-stack
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
pydantic-core==2.27.2
|
pydantic-core==2.27.2
|
||||||
# via pydantic
|
|
||||||
pygments==2.19.1
|
pygments==2.19.1
|
||||||
# via rich
|
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
# via
|
|
||||||
# kubernetes
|
|
||||||
# pandas
|
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
# via llama-stack
|
|
||||||
pytz==2025.1
|
pytz==2025.1
|
||||||
# via pandas
|
|
||||||
pyyaml==6.0.2
|
pyyaml==6.0.2
|
||||||
# via
|
|
||||||
# huggingface-hub
|
|
||||||
# kubernetes
|
|
||||||
# pyaml
|
|
||||||
referencing==0.36.2
|
referencing==0.36.2
|
||||||
# via
|
|
||||||
# jsonschema
|
|
||||||
# jsonschema-specifications
|
|
||||||
regex==2024.11.6
|
regex==2024.11.6
|
||||||
# via tiktoken
|
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
# via
|
|
||||||
# huggingface-hub
|
|
||||||
# kubernetes
|
|
||||||
# llama-stack
|
|
||||||
# requests-oauthlib
|
|
||||||
# tiktoken
|
|
||||||
requests-oauthlib==2.0.0
|
requests-oauthlib==2.0.0
|
||||||
# via kubernetes
|
|
||||||
rich==13.9.4
|
rich==13.9.4
|
||||||
# via
|
|
||||||
# llama-stack
|
|
||||||
# llama-stack-client
|
|
||||||
rpds-py==0.22.3
|
rpds-py==0.22.3
|
||||||
# via
|
|
||||||
# jsonschema
|
|
||||||
# referencing
|
|
||||||
rsa==4.9
|
rsa==4.9
|
||||||
# via google-auth
|
|
||||||
setuptools==75.8.0
|
setuptools==75.8.0
|
||||||
# via llama-stack
|
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
# via
|
|
||||||
# kubernetes
|
|
||||||
# python-dateutil
|
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
# via
|
|
||||||
# anyio
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
termcolor==2.5.0
|
termcolor==2.5.0
|
||||||
# via
|
|
||||||
# fire
|
|
||||||
# llama-stack
|
|
||||||
# llama-stack-client
|
|
||||||
tiktoken==0.9.0
|
tiktoken==0.9.0
|
||||||
# via llama-stack
|
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
# via
|
|
||||||
# huggingface-hub
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
typing-extensions==4.12.2
|
typing-extensions==4.12.2
|
||||||
# via
|
|
||||||
# anyio
|
|
||||||
# huggingface-hub
|
|
||||||
# llama-stack-client
|
|
||||||
# openai
|
|
||||||
# pydantic
|
|
||||||
# pydantic-core
|
|
||||||
# referencing
|
|
||||||
# rich
|
|
||||||
tzdata==2025.1
|
tzdata==2025.1
|
||||||
# via pandas
|
|
||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
# via
|
|
||||||
# blobfile
|
|
||||||
# kubernetes
|
|
||||||
# requests
|
|
||||||
wcwidth==0.2.13
|
wcwidth==0.2.13
|
||||||
# via prompt-toolkit
|
|
||||||
websocket-client==1.8.0
|
websocket-client==1.8.0
|
||||||
# via kubernetes
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue