mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-26 10:02:00 +00:00
feat: add auto-generated CI documentation pre-commit hook (#2890)
Our CI is entirely undocumented, this commit adds a README.md file with a table of the current CI and what is does --------- Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent
7f834339ba
commit
b381ed6d64
93 changed files with 495 additions and 477 deletions
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
|
@ -44,7 +44,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class HFFinetuningSingleDevice:
|
||||
|
|
@ -69,14 +69,14 @@ class HFFinetuningSingleDevice:
|
|||
try:
|
||||
messages = json.loads(row["chat_completion_input"])
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
if "content" not in messages[0]:
|
||||
logger.warning(f"Message missing content: {messages[0]}")
|
||||
log.warning(f"Message missing content: {messages[0]}")
|
||||
return None, None
|
||||
return messages[0]["content"], row["expected_answer"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
|
@ -86,13 +86,13 @@ class HFFinetuningSingleDevice:
|
|||
try:
|
||||
dialog = json.loads(row["dialog"])
|
||||
if not isinstance(dialog, list) or len(dialog) < 2:
|
||||
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
log.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
return None, None
|
||||
if dialog[0].get("role") != "user":
|
||||
logger.warning(f"First message must be from user: {dialog[0]}")
|
||||
log.warning(f"First message must be from user: {dialog[0]}")
|
||||
return None, None
|
||||
if not any(msg.get("role") == "assistant" for msg in dialog):
|
||||
logger.warning("Dialog must have at least one assistant message")
|
||||
log.warning("Dialog must have at least one assistant message")
|
||||
return None, None
|
||||
|
||||
# Convert to human/gpt format
|
||||
|
|
@ -100,14 +100,14 @@ class HFFinetuningSingleDevice:
|
|||
conversations = []
|
||||
for msg in dialog:
|
||||
if "role" not in msg or "content" not in msg:
|
||||
logger.warning(f"Message missing role or content: {msg}")
|
||||
log.warning(f"Message missing role or content: {msg}")
|
||||
continue
|
||||
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
|
||||
|
||||
# Format as a single conversation
|
||||
return conversations[0]["value"], conversations[1]["value"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
log.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
|
@ -198,7 +198,7 @@ class HFFinetuningSingleDevice:
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
logger.info("Starting training process with async wrapper")
|
||||
log.info("Starting training process with async wrapper")
|
||||
asyncio.run(
|
||||
self._run_training(
|
||||
model=model,
|
||||
|
|
@ -228,14 +228,14 @@ class HFFinetuningSingleDevice:
|
|||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
log.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
log.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
logger.info(f"Initializing tokenizer for model: {model}")
|
||||
log.info(f"Initializing tokenizer for model: {model}")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||
|
||||
|
|
@ -257,16 +257,16 @@ class HFFinetuningSingleDevice:
|
|||
# This ensures consistent sequence lengths across the training process
|
||||
tokenizer.model_max_length = provider_config.max_seq_length
|
||||
|
||||
logger.info("Tokenizer initialized successfully")
|
||||
log.info("Tokenizer initialized successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||
|
||||
# Create and preprocess dataset
|
||||
logger.info("Creating and preprocessing dataset")
|
||||
log.info("Creating and preprocessing dataset")
|
||||
try:
|
||||
ds = self._create_dataset(rows, config, provider_config)
|
||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||
logger.info(f"Dataset created with {len(ds)} examples")
|
||||
log.info(f"Dataset created with {len(ds)} examples")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||
|
||||
|
|
@ -293,11 +293,11 @@ class HFFinetuningSingleDevice:
|
|||
Returns:
|
||||
Configured SFTConfig object
|
||||
"""
|
||||
logger.info("Configuring training arguments")
|
||||
log.info("Configuring training arguments")
|
||||
lr = 2e-5
|
||||
if config.optimizer_config:
|
||||
lr = config.optimizer_config.lr
|
||||
logger.info(f"Using custom learning rate: {lr}")
|
||||
log.info(f"Using custom learning rate: {lr}")
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
|
|
@ -350,17 +350,17 @@ class HFFinetuningSingleDevice:
|
|||
peft_config: Optional LoRA configuration
|
||||
output_dir_path: Path to save the model
|
||||
"""
|
||||
logger.info("Saving final model")
|
||||
log.info("Saving final model")
|
||||
model_obj.config.use_cache = True
|
||||
|
||||
if peft_config:
|
||||
logger.info("Merging LoRA weights with base model")
|
||||
log.info("Merging LoRA weights with base model")
|
||||
model_obj = trainer.model.merge_and_unload()
|
||||
else:
|
||||
model_obj = trainer.model
|
||||
|
||||
save_path = output_dir_path / "merged_model"
|
||||
logger.info(f"Saving model to {save_path}")
|
||||
log.info(f"Saving model to {save_path}")
|
||||
model_obj.save_pretrained(save_path)
|
||||
|
||||
async def _run_training(
|
||||
|
|
@ -380,13 +380,13 @@ class HFFinetuningSingleDevice:
|
|||
setup_signal_handlers()
|
||||
|
||||
# Convert config dicts back to objects
|
||||
logger.info("Initializing configuration objects")
|
||||
log.info("Initializing configuration objects")
|
||||
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
|
||||
config_obj = TrainingConfig(**config)
|
||||
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config_obj.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
log.info(f"Using device '{device}'")
|
||||
|
||||
# Load dataset and tokenizer
|
||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
|
||||
|
|
@ -409,7 +409,7 @@ class HFFinetuningSingleDevice:
|
|||
model_obj = load_model(model, device, provider_config_obj)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing SFTTrainer")
|
||||
log.info("Initializing SFTTrainer")
|
||||
trainer = SFTTrainer(
|
||||
model=model_obj,
|
||||
train_dataset=train_dataset,
|
||||
|
|
@ -420,9 +420,9 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
try:
|
||||
# Train
|
||||
logger.info("Starting training")
|
||||
log.info("Starting training")
|
||||
trainer.train()
|
||||
logger.info("Training completed successfully")
|
||||
log.info("Training completed successfully")
|
||||
|
||||
# Save final model if output directory is provided
|
||||
if output_dir_path:
|
||||
|
|
@ -430,12 +430,12 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
finally:
|
||||
# Clean up resources
|
||||
logger.info("Cleaning up resources")
|
||||
log.info("Cleaning up resources")
|
||||
if hasattr(trainer, "model"):
|
||||
evacuate_model_from_device(trainer.model, device.type)
|
||||
del trainer
|
||||
gc.collect()
|
||||
logger.info("Cleanup completed")
|
||||
log.info("Cleanup completed")
|
||||
|
||||
async def train(
|
||||
self,
|
||||
|
|
@ -449,7 +449,7 @@ class HFFinetuningSingleDevice:
|
|||
"""Train a model using HuggingFace's SFTTrainer"""
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
log.info(f"Using device '{device}'")
|
||||
|
||||
output_dir_path = None
|
||||
if output_dir:
|
||||
|
|
@ -479,7 +479,7 @@ class HFFinetuningSingleDevice:
|
|||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Train in a separate process
|
||||
logger.info("Starting training in separate process")
|
||||
log.info("Starting training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
if device.type in ["cuda", "mps"]:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
|
|||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
|
@ -40,7 +40,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class HFDPOAlignmentSingleDevice:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
|
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
|
|||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def setup_environment():
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|||
from torchtune import modules, training
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
|
|
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
|
|||
)
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
|
|
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue