temp commit

Commit 9a976bcabd (parent d7598c68d7)
9 changed files with 310 additions and 23 deletions
@@ -9,7 +9,7 @@ from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis import post_training

from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
@@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from pydantic import BaseModel, Field


@json_schema_type
@@ -28,6 +29,7 @@ class Api(Enum):
    datasetio = "datasetio"
    scoring = "scoring"
    eval = "eval"
    post_training = "post_training"

    telemetry = "telemetry"
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import MetaReferencePostTrainingConfig


async def get_provider_impl(
    config: MetaReferencePostTrainingConfig,
    deps: Dict[Api, ProviderSpec],
):
    from .post_training import MetaReferencePostTrainingImpl

    impl = MetaReferencePostTrainingImpl(
        config,
        deps[Api.datasetio],
    )
    # await impl.initialize()
    return impl
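For orientation, a minimal caller-side sketch of how this entrypoint might be exercised (the `resolve_post_training` wrapper and its arguments are assumptions, not part of the diff; the stack resolver is expected to pass the provider config together with the resolved datasetio dependency keyed by the Api enum added above):

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.meta_reference import get_provider_impl


async def resolve_post_training(config, datasetio_impl):
    # config: a MetaReferencePostTrainingConfig; datasetio_impl: the already
    # resolved datasetio provider that backs deps[Api.datasetio] above.
    return await get_provider_impl(config, {Api.datasetio: datasetio_impl})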
@@ -5,7 +5,8 @@
# the root directory of this source tree.

from typing import Optional
from pydantic import BaseModel, Field,

from pydantic import BaseModel, Field


class MetaReferencePostTrainingConfig(BaseModel):
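The body of MetaReferencePostTrainingConfig falls outside this hunk; judging only from how the recipe below reads `self.config.checkpoint_dir`, a minimal sketch could look like the following (the field and its default are assumptions, not the committed code):

from typing import Optional

from pydantic import BaseModel, Field


class MetaReferencePostTrainingConfig(BaseModel):
    # The recipe falls back to ~/.llama/checkpoints/<model_id> when this is unset.
    checkpoint_dir: Optional[str] = Field(default=None)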
@@ -4,11 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping

import numpy as np

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
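For context on the torchtune imports above: CROSS_ENTROPY_IGNORE_IDX is the label value that the cross-entropy loss skips, which an SFT dataset typically assigns to prompt tokens so that only completion tokens are trained on. A small illustrative sketch, not the code from this file:

from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX


def mask_prompt_labels(input_ids, prompt_len):
    # Labels start as a copy of the token ids; prompt positions are then
    # replaced with the ignore index so they contribute nothing to the loss.
    labels = list(input_ids)
    labels[:prompt_len] = [CROSS_ENTROPY_IGNORE_IDX] * prompt_len
    return labels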
@@ -1,6 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.meta_reference.config import (
    MetaReferencePostTrainingConfig,
)
from llama_stack.apis.post_training import *  # noqa
from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
    LoraFinetuningSingleDevice,
)


class MetaReferencePostTrainingImpl:
    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
    def __init__(
        self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api

    def supervised_fine_tune(
        self,
@@ -27,7 +45,12 @@ class MetaReferencePostTrainingImpl:
            logger_config=logger_config,
        )
        if request.algorithm == FinetuningAlgorithm.lora:
            recipe = LoraFinetuningRecipeSingleDevice(self.config, request)
            recipe = LoraFinetuningSingleDevice(
                self.config, request, self.datasetio_api
            )
            recipe.setup(self.config)
            recipe.train()
        else:
            raise NotImplementedError()

        return PostTrainingJob(job_uuid=job_uuid)
@@ -3,15 +3,21 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import os
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from llama_stack.apis.datasetio import DatasetIO
from torch import nn
from llama_stack.apis.post_training import *  # noqa
from llama_stack.apis.post_training import PostTrainingSFTRequest

from llama_stack.providers.inline.post_training.meta_reference import utils
from llama_stack.providers.inline.post_training.meta_reference.configs import (
from llama_stack.providers.inline.post_training.meta_reference.config import (
    MetaReferencePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
@@ -36,7 +42,7 @@ from torchtune.training.lr_scheduler import get_cosine_schedule_with_warmup

log = logging.getLogger(__name__)

Tokenizer = Union[Llama3Tokenizer]
from torchtune.models.llama3._tokenizer import Llama3Tokenizer


class LoraFinetuningSingleDevice:
@@ -44,13 +50,13 @@ class LoraFinetuningSingleDevice:
        self,
        config: MetaReferencePostTrainingConfig,
        request: PostTrainingSFTRequest,
        datasetio_api: DatasetIOAPI,
        datasetio_api: DatasetIO,
    ) -> None:
        # to make user config easier, assume the device is 'cuda' only
        # self._device = utils.get_device(device=cfg.device)
        self.config = config
        self.request = request
        self._device = "cuda"
        self._device = training.utils.get_device(device="cuda")
        self._dtype = training.get_dtype(
            request.training_config.dtype, device=self._device
        )
@@ -68,6 +74,10 @@ class LoraFinetuningSingleDevice:
        self._shuffle = request.training_config.shuffle
        self._batch_size = request.training_config.batch_size

        self.checkpoint_dir = (
            self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}"
        )

        # this is important for debugging purpose
        self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
        self.global_step = 0
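One caveat with the fallback above: a literal "~" in a path like ~/.llama/checkpoints/<model_id> is not expanded by filesystem calls, so any consumer of self.checkpoint_dir would normally expand it first. A small sketch of that step (the helper name is hypothetical):

import os


def expanded_checkpoint_dir(checkpoint_dir: str) -> str:
    # os.listdir("~/.llama/...") would fail as-is; expanduser turns "~" into
    # the user's home directory before the path is used.
    return os.path.expanduser(checkpoint_dir)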
@@ -98,7 +108,7 @@ class LoraFinetuningSingleDevice:
                pth_files = [file for file in files if file.endswith(".pth")]
                return pth_files
            except FileNotFoundError:
                return f"Error: The directory '{checkpoint_dir}' does not exist."
                return [f"Error: The directory '{checkpoint_dir}' does not exist."]

        self._checkpointer = training.FullModelMetaCheckpointer(
            checkpoint_dir=self.config.checkpoint_dir,
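The changed line above keeps the helper's return type a list, but the error message then travels in-band as if it were a filename. An alternative sketch that raises instead (the name get_checkpoint_files is assumed; the surrounding method is not shown in full here):

import os
from typing import List


def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
    try:
        files = os.listdir(checkpoint_dir)
    except FileNotFoundError as e:
        # Surface the missing directory as an exception rather than a string
        # smuggled into the returned list.
        raise FileNotFoundError(
            f"The directory '{checkpoint_dir}' does not exist."
        ) from e
    return [f for f in files if f.endswith(".pth")]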
@@ -133,7 +143,7 @@ class LoraFinetuningSingleDevice:
        log.info("Tokenizer is initialized from file.")

        self._optimizer = self._setup_optimizer(
            optimizer_config=self.request.training_config.optimizer, opt_state_dict=None
            optimizer_config=self.request.training_config.optimizer
        )

        self._loss_fn = CEWithChunkedOutputLoss()
@@ -253,8 +263,8 @@ class LoraFinetuningSingleDevice:

    def _setup_tokenizer(
        self,
    ) -> Tokenizer:
        tokenizer_path = self.config.checkpoint_dir + "/tokenizer.model"
    ) -> Llama3Tokenizer:
        tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
        tokenizer_type = utils.get_tokenizer_type(self.model_id)
        return tokenizer_type(path=tokenizer_path)
@@ -270,13 +280,17 @@ class LoraFinetuningSingleDevice:
        log.info("Optimizer and loss are initialized.")
        return optimizer

    async def _setup_data(
        self, tokenizer: Tokenizer, shuffle: bool, batch_size: int
    def _setup_data(
        self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
    ) -> Tuple[DistributedSampler, DataLoader]:
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=self.request.dataset_id,
            rows_in_page=-1,
        )
        async def fetch_rows():
            return await self.datasetio_api.get_rows_paginated(
                dataset_id=self.request.dataset_id,
                rows_in_page=-1,
            )

        # Run the async function in an event loop
        all_rows = asyncio.run(fetch_rows())
        rows = all_rows.rows

        ds = SFTDataset(
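The rewrite above bridges the async datasetio API into the now-synchronous _setup_data. A standalone sketch of the same pattern, with one caveat that is a property of asyncio itself: asyncio.run() raises RuntimeError if an event loop is already running, so this only works when the recipe is driven from synchronous code.

import asyncio


def fetch_all_rows_sync(datasetio_api, dataset_id):
    async def _fetch():
        # Same call the recipe makes: page size -1 requests every row at once.
        return await datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
        )

    # Creates and tears down a fresh event loop; fails if one is already running.
    return asyncio.run(_fetch())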
@@ -323,3 +337,198 @@ class LoraFinetuningSingleDevice:

        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def save_checkpoint(self, epoch: int) -> None:
        """
        Checkpoint the state of the recipe. The constructed checkpoint state dict
        contains the following information:
        - Merged weights with key MODEL_KEY
        - Adapter weights with key ADAPTER_KEY
        - Relevant recipe state if training is not complete
        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights

        To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
        """
        ckpt_dict = {}

        intermediate_checkpoint = epoch + 1 < self.total_epochs

        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

        # Construct the full state dict with LoRA weights merged into base LLM weights

        # Move to CPU to avoid a copy on GPU
        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

        merged_state_dict = get_merged_lora_ckpt(
            state_dict,
            rank=self._lora_rank,
            alpha=self._lora_alpha,
        )

        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

        adapter_config = {
            "r": self._lora_rank,
            "lora_alpha": self._lora_alpha,
            "target_modules": get_lora_module_names(
                self._lora_attn_modules,
                self._apply_lora_to_mlp,
                self._apply_lora_to_output,
            ),
            "peft_type": "LORA",
        }
        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

        self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
            intermediate_checkpoint=intermediate_checkpoint,
        )
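For reference, the merge that get_merged_lora_ckpt performs folds each adapter pair into its base weight: for a layer with weight W, LoRA matrices A and B, rank r and scaling alpha, the merged weight is W + (alpha / r) * B @ A. A toy sketch of that arithmetic only (not torchtune's implementation, which also rewrites the state-dict keys):

import torch


def merge_lora_weight(w, lora_a, lora_b, rank, alpha):
    # w: [out_dim, in_dim], lora_a: [rank, in_dim], lora_b: [out_dim, rank]
    return w + (alpha / rank) * (lora_b @ lora_a)


w = torch.zeros(8, 16)
a, b = torch.randn(4, 16), torch.randn(8, 4)
merged = merge_lora_weight(w, a, b, rank=4, alpha=8)  # shape [8, 16]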
    def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Shape [b, s], needed for the loss not the model
        labels = batch.pop("labels")
        # run model
        with self.activations_handling_ctx:
            logits = self._model(**batch)

        # Shift labels to compute loss
        # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
        # But this way we dont need to slice the logits. We just add an ignore index to labels.
        labels = torch.hstack(
            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
        )
        if not isinstance(logits, list):
            labels = labels.reshape(-1)
            logits = logits.reshape(-1, logits.size(-1))

        loss = self._loss_fn(logits, labels)

        # free logits otherwise it peaks backward memory
        del logits

        return loss
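A tiny numeric illustration of the shift above (the values are made up): position t ends up labeled with token t+1, and the final position, which has no next token, is ignored by the loss.

import torch

IGNORE_IDX = -100  # matches CROSS_ENTROPY_IGNORE_IDX

labels = torch.tensor([[10, 11, 12, 13]])
ignore_col = torch.full((labels.shape[0], 1), IGNORE_IDX)
shifted = torch.hstack((labels[..., 1:], ignore_col))
# shifted == tensor([[ 11, 12, 13, -100]])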
    def train(self) -> None:
        """
        The core training loop.
        """

        # if self._compile:
        #     log.info(
        #         "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
        #     )

        # Initialize tokens count and running loss (for grad accumulation)
        # t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            # pbar = tqdm(total=self._steps_per_epoch)
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Start tracking CUDA memory for active steps for just the first epoch
                # if (
                #     curr_epoch == 0
                #     and self.profiler_profile_memory
                #     and idx == self.profiler_wait_steps + self.profiler_warmup_steps
                # ):
                #     torch.cuda.memory._record_memory_history()

                training.utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
                current_loss = self._loss_step(batch) * current_num_tokens
                running_loss += current_loss
                current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    training.scale_grads(self._model, 1 / num_tokens)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self._model.parameters(),
                        max_norm=float(self._clip_grad_norm),
                    )
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()
                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    # loss_to_log = running_loss.item() / num_tokens
                    # pbar.update(1)
                    # pbar.set_description(
                    #     f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    # )

                    # Log per-step metrics
                    # if self.global_step % self._log_every_n_steps == 0:
                    #     time_per_step = time.perf_counter() - t0
                    #     log_dict = {
                    #         "loss": loss_to_log,
                    #         "lr": self._optimizer.param_groups[0]["lr"],
                    #         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                    #     }
                    #     if self._device.type == "cuda" and self._log_peak_memory_stats:
                    #         log_dict.update(
                    #             training.get_memory_stats(device=self._device)
                    #         )
                    #     if self._clip_grad_norm is not None:
                    #         log_dict.update({"grad_norm": grad_norm})
                    #     self._metric_logger.log_dict(
                    #         log_dict,
                    #         step=self.global_step,
                    #     )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    # t0 = time.perf_counter()

                # Stop tracking CUDA memory now that active steps are complete
                # if (
                #     curr_epoch == 0
                #     and self.profiler_profile_memory
                #     and idx
                #     == self.profiler_wait_steps
                #     + self.profiler_warmup_steps
                #     + self.profiler_active_steps
                # ):
                #     torch.cuda.memory._record_memory_history(enabled=None)

                # Step the profiler
                # Note we are stepping each batch, which might not include optimizer step in the trace
                # if the schedule cycle doesn't align with gradient accumulation.
                # prof.step()

            self.epochs_run += 1
            # start_save_checkpoint = time.perf_counter()
            log.info("Starting checkpoint save...")
            self.save_checkpoint(epoch=curr_epoch)
            # log.info(
            #     "Checkpoint saved in {:.2f} seconds.".format(
            #         time.perf_counter() - start_save_checkpoint
            #     )
            # )
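The accumulation scheme above scales each micro-batch loss by its unmasked-token count and then rescales the accumulated gradients by 1 / total tokens at the optimizer step, which is equivalent to averaging the loss over every unmasked token in the accumulation window. A self-contained toy version of that arithmetic (a sketch only, with a linear layer standing in for the model and the batch size standing in for the token count):

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_steps = 2
num_tokens = 0

for step in range(1, 5):
    x, target = torch.randn(3, 4), torch.randn(3, 2)
    n_tokens = x.shape[0]  # stand-in for the unmasked-label count
    loss = nn.functional.mse_loss(model(x), target) * n_tokens
    loss.backward()  # gradients keep accumulating across micro-batches
    num_tokens += n_tokens
    if step % grad_accum_steps == 0:
        for p in model.parameters():  # the equivalent of training.scale_grads
            if p.grad is not None:
                p.grad.div_(num_tokens)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        num_tokens = 0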
@@ -30,15 +30,13 @@ BuildTokenizerCallable = Callable[..., Llama3Tokenizer]


def get_model_type(
    self,
    model_id: str,
) -> BuildLoraModelCallable:
    model = resolve_model(model_id)
    return LORA_MODEL_TYPES[model.core_model_id.value]


def get_tokenizer(
    self,
def get_tokenizer_type(
    model_id: str,
) -> BuildTokenizerCallable:
    model = resolve_model(model_id)
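A hypothetical call site for these helpers (the model id string and the tokenizer path below are illustrative only, and assume the id is registered in the lookup tables this module defines):

lora_model_builder = get_model_type("Llama3.2-3B-Instruct")
tokenizer_builder = get_tokenizer_type("Llama3.2-3B-Instruct")
tokenizer = tokenizer_builder(path="/path/to/checkpoints/tokenizer.model")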
llama_stack/providers/registry/post_training.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


META_REFERENCE_DEPS = [
    "torch",
    "torchtune",
    "numpy",
]


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.post_training,
            provider_type="inline::meta-reference",
            pip_packages=META_REFERENCE_DEPS,
            module="llama_stack.providers.inline.post_training.meta_reference",
            config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
        ),
    ]
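A small sketch of how this spec might be looked up (illustrative only; the real stack resolver does considerably more than this, and only the fields shown in the constructor above are assumed to exist on the spec):

from llama_stack.providers.registry.post_training import available_providers

specs = {spec.provider_type: spec for spec in available_providers()}
meta_reference = specs["inline::meta-reference"]
print(meta_reference.module, meta_reference.pip_packages)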
@@ -4,6 +4,8 @@ distribution_spec:
  description: Use Meta Reference for running LLM inference
  docker_image: null
  providers:
    post_training:
    - inline::meta-reference
    inference:
    - inline::meta-reference
    memory: