Mirror of https://github.com/meta-llama/llama-stack.git

commit dc3d9d7720 (parent e5993c565e): init commit
3 changed files with 68 additions and 10 deletions
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from datetime import datetime
-from typing import Optional

 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
@@ -26,4 +25,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metric: PostTrainingMetric
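Note: for orientation, the Checkpoint and PostTrainingMetric models after this change would look roughly like the sketch below. PostTrainingMetric's field types are assumptions inferred from the constructor call in the recipe hunk further down; only the Checkpoint fields shown in the hunks above are confirmed by the diff.

# Minimal sketch, reconstructed from this diff; PostTrainingMetric field
# types are assumptions based on the values passed to it later
# (curr_epoch, loss_to_log, validation_loss, perplexity).
from datetime import datetime

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel


@json_schema_type
class PostTrainingMetric(BaseModel):
    epoch: int
    train_loss: float
    validation_loss: float
    perplexity: float


@json_schema_type
class Checkpoint(BaseModel):
    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str
    training_metric: PostTrainingMetric  # previously Optional[...] = None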
@@ -180,11 +180,18 @@ class LoraFinetuningSingleDevice:
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")

-        self._sampler, self._dataloader = await self._setup_data(
+        self._training_sampler, self._training_dataloader = await self._setup_data(
+            dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )
+        _, self._validation_dataloader = await self._setup_data(
+            dataset_id=self.training_config.data_config.validation_dataset_id,
+            tokenizer=self._tokenizer,
+            shuffle=False,
+            batch_size=self._batch_size,
+        )
         log.info("Dataset and Sampler are initialized.")

         # Number of training steps in each epoch depends on the number of batches produced
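Note: the hunk above now builds two loaders from the same helper: a shuffling training loader whose DistributedSampler is kept so set_epoch can reseed it each epoch, and a deterministic validation loader whose sampler is discarded. A minimal PyTorch sketch of that pairing, with a toy dataset and an illustrative make_loader helper (not the recipe's actual _setup_data):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy dataset standing in for the tokenized SFT dataset (illustrative only).
dataset = TensorDataset(torch.arange(100))

def make_loader(ds, shuffle: bool, batch_size: int):
    # The DistributedSampler owns the shuffling, so the DataLoader itself
    # must not shuffle; num_replicas/rank are fixed for single-device use.
    sampler = DistributedSampler(ds, num_replicas=1, rank=0, shuffle=shuffle, seed=0)
    return sampler, DataLoader(ds, batch_size=batch_size, sampler=sampler)

training_sampler, training_dataloader = make_loader(dataset, shuffle=True, batch_size=8)
_, validation_dataloader = make_loader(dataset, shuffle=False, batch_size=8)

# The training sampler is kept so each epoch can reseed the shuffle:
training_sampler.set_epoch(0)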
@@ -192,7 +199,7 @@ class LoraFinetuningSingleDevice:
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
         self._steps_per_epoch = (
-            len(self._dataloader) // self._gradient_accumulation_steps
+            len(self._training_dataloader) // self._gradient_accumulation_steps
         )
         if (
             self.max_steps_per_epoch is not None
@@ -311,15 +318,19 @@ class LoraFinetuningSingleDevice:
         return optimizer

     async def _setup_data(
-        self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
+        self,
+        dataset_id: str,
+        tokenizer: Llama3Tokenizer,
+        shuffle: bool,
+        batch_size: int,
     ) -> Tuple[DistributedSampler, DataLoader]:
-        async def fetch_rows():
+        async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.training_config.data_config.dataset_id,
+                dataset_id=dataset_id,
                 rows_in_page=-1,
             )

-        all_rows = await fetch_rows()
+        all_rows = await fetch_rows(dataset_id)
         rows = all_rows.rows

         # Currently only supports the alpaca instruct dataset
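Note: after this refactor _setup_data is split-agnostic; the caller supplies dataset_id and the inner fetch_rows forwards it to the datasetio API, so the same code path serves both the training and validation datasets. A hedged standalone sketch of the fetch pattern (the result's .rows attribute follows its use above; the helper name fetch_all_rows is illustrative):

async def fetch_all_rows(datasetio_api, dataset_id: str):
    # rows_in_page=-1 requests the entire dataset in a single page
    result = await datasetio_api.get_rows_paginated(
        dataset_id=dataset_id,
        rows_in_page=-1,
    )
    return result.rows

# train_rows = await fetch_all_rows(self.datasetio_api, "alpaca")
# val_rows = await fetch_all_rows(self.datasetio_api, "alpaca_eval")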
@@ -447,9 +458,10 @@ class LoraFinetuningSingleDevice:
             metric_logger = DiskLogger(
                 log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
             )
-            self._sampler.set_epoch(curr_epoch)
+            self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0
-            for idx, batch in enumerate(self._dataloader):
+
+            for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
                     and (idx // self._gradient_accumulation_steps)
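Note: the truncated condition above compares the batch index, divided by the gradient-accumulation factor, against max_steps_per_epoch, so the cap counts optimizer steps rather than raw batches. A small sketch of that arithmetic (the equality comparison and helper name are assumptions, since the hunk cuts off mid-expression):

from typing import Optional

def hit_step_cap(batch_idx: int, grad_accum: int, max_steps_per_epoch: Optional[int]) -> bool:
    # One optimizer step consumes grad_accum batches, so convert the raw
    # batch index into an optimizer-step count before comparing to the cap.
    return (
        max_steps_per_epoch is not None
        and (batch_idx // grad_accum) == max_steps_per_epoch
    )

# With grad_accum=4 and max_steps_per_epoch=2, batches 0-7 run (two full
# optimizer steps) and the loop stops at batch index 8.
assert not hit_step_cap(7, 4, 2)
assert hit_step_cap(8, 4, 2)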
@@ -512,13 +524,45 @@ class LoraFinetuningSingleDevice:
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
             checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            validation_loss, perplexity = await self.validate()
+            training_metric = PostTrainingMetric(
+                epoch=curr_epoch,
+                train_loss=loss_to_log,
+                validation_loss=validation_loss,
+                perplexity=perplexity,
+            )
             checkpoint = Checkpoint(
                 identifier=f"{self.model_id}-sft-{curr_epoch}",
                 created_at=datetime.now(),
                 epoch=curr_epoch,
                 post_training_job_id=self.job_uuid,
                 path=checkpoint_path,
+                training_metric=training_metric,
             )
             checkpoints.append(checkpoint)

         return (memory_stats, checkpoints)
+
+    async def validate(self) -> Tuple[float, float]:
+        total_loss = 0.0
+        total_tokens = 0
+        for idx, batch in enumerate(self._validation_dataloader):
+            if idx == 10:
+                break
+            torchtune_utils.batch_to_device(batch, self._device)
+
+            # Calculate the number of unmasked tokens in the current batch
+            # and increment the total number of tokens seen in the step
+            num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
+
+            # Loss is normalized by default so we multiply by the number of tokens
+            # This way we can normalize by the total number of tokens if we're accumulating gradients
+            loss = await self._loss_step(batch) * num_tokens
+
+            total_loss += loss
+            total_tokens += num_tokens
+
+        mean_loss = total_loss / total_tokens
+        perplexity = torch.exp(torch.tensor(mean_loss))
+
+        return mean_loss, perplexity.item()
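Note: validate() computes a token-weighted mean cross-entropy over at most 10 validation batches, then exponentiates it to get perplexity. A standalone sketch of the same computation on plain floats:

import math

def token_weighted_perplexity(batch_losses: list[float], batch_tokens: list[int]):
    # Each per-batch loss is already a per-token mean, so re-weight by the
    # number of unmasked tokens before averaging across batches.
    total_loss = sum(loss * n for loss, n in zip(batch_losses, batch_tokens))
    mean_loss = total_loss / sum(batch_tokens)
    return mean_loss, math.exp(mean_loss)

# Two batches: 100 tokens at loss 2.0 and 50 tokens at loss 2.6
mean_loss, ppl = token_weighted_perplexity([2.0, 2.6], [100, 50])
# mean_loss == 2.2, ppl ≈ 9.03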
@@ -49,5 +49,20 @@ datasets:
       type: string
     text:
       type: string
+- dataset_id: alpaca_eval
+  provider_id: huggingface-0
+  url:
+    uri: https://huggingface.co/datasets/causal-lm/code_alpaca
+  metadata:
+    path: causal-lm/code_alpaca
+    name:
+    split: validation
+  dataset_schema:
+    instruction:
+      type: string
+    input:
+      type: string
+    output:
+      type: string
 scoring_fns: []
 eval_tasks: []
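Note: this registers the alpaca_eval dataset that the recipe's validation path consumes via validation_dataset_id. A hedged sketch of how a job's data config would point at it; the key names mirror the attributes the recipe reads (data_config.dataset_id, data_config.validation_dataset_id), but the exact config class and its other fields are assumptions:

data_config = {
    "dataset_id": "alpaca",                  # training split, registered elsewhere
    "validation_dataset_id": "alpaca_eval",  # the dataset added above
    "batch_size": 4,
    "shuffle": True,
}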