commit 8132b4e177 (parent dc3d9d7720)
refine

3 changed files with 34 additions and 31 deletions
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from datetime import datetime
+from typing import Optional

 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
@@ -25,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metric: PostTrainingMetric
+    training_metric: Optional[PostTrainingMetric] = None
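With training_metric now Optional and defaulting to None, a Checkpoint can be created before any validation has run and the metric attached afterwards, which is how the recipe change below uses it. A minimal illustrative sketch (field values and the ran_validation flag are made up, not from this commit):

    from datetime import datetime

    from llama_stack.apis.post_training import *  # noqa  (brings in Checkpoint, PostTrainingMetric)

    # Checkpoint without metrics; training_metric stays None.
    checkpoint = Checkpoint(
        identifier="my-model-sft-0",        # illustrative
        created_at=datetime.now(),
        epoch=0,
        post_training_job_id="job-1234",    # illustrative
        path="/tmp/checkpoints/epoch_0",    # illustrative
    )

    # Attach metrics only if validation actually ran.
    if ran_validation:  # hypothetical flag
        checkpoint.training_metric = PostTrainingMetric(
            epoch=0,
            train_loss=1.23,
            validation_loss=1.45,
            perplexity=4.26,
        )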
@@ -23,6 +23,7 @@ from llama_stack.providers.inline.post_training.torchtune.common.checkpointer im
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 from llama_stack.apis.post_training import * # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir

@@ -186,12 +187,15 @@ class LoraFinetuningSingleDevice:
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )

-        _, self._validation_dataloader = await self._setup_data(
-            dataset_id=self.training_config.data_config.validation_dataset_id,
-            tokenizer=self._tokenizer,
-            shuffle=False,
-            batch_size=self._batch_size,
-        )
+        if self.training_config.data_config.validation_dataset_id:
+            _, self._validation_dataloader = await self._setup_data(
+                dataset_id=self.training_config.data_config.validation_dataset_id,
+                tokenizer=self._tokenizer,
+                shuffle=False,
+                batch_size=self._batch_size,
+            )

         log.info("Dataset and Sampler are initialized.")

         # Number of training steps in each epoch depends on the number of batches produced
@@ -461,6 +465,7 @@ class LoraFinetuningSingleDevice:
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0

+            pbar = tqdm(total=self._steps_per_epoch)
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -498,6 +503,12 @@ class LoraFinetuningSingleDevice:
                     self.global_step += 1

                     loss_to_log = running_loss.item() / num_tokens

+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
                     time_per_step = time.perf_counter() - t0
                     log_dict = {
                         "loss": loss_to_log,
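The progress-bar additions follow the standard tqdm pattern: create the bar with the expected number of steps, then advance and relabel it each time an optimizer step completes. A standalone sketch of that pattern (the loop and loss values are placeholders, not the recipe's real training loop):

    from tqdm import tqdm

    steps_per_epoch = 100  # placeholder value
    pbar = tqdm(total=steps_per_epoch)
    for step in range(steps_per_epoch):
        loss = 1.0 / (step + 1)  # stand-in for the real training loss
        pbar.update(1)
        pbar.set_description(f"1|{step + 1}|Loss: {loss}")
    pbar.close()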
@@ -524,28 +535,31 @@ class LoraFinetuningSingleDevice:
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
             checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
-            validation_loss, perplexity = await self.validate()
-            training_metreic = PostTrainingMetric(
-                epoch=curr_epoch,
-                train_loss=loss_to_log,
-                validation_loss=validation_loss,
-                perplexity=perplexity,
-            )
             checkpoint = Checkpoint(
                 identifier=f"{self.model_id}-sft-{curr_epoch}",
                 created_at=datetime.now(),
                 epoch=curr_epoch,
                 post_training_job_id=self.job_uuid,
                 path=checkpoint_path,
-                training_metric=training_metreic,
             )
+            if self.training_config.data_config.validation_dataset_id:
+                validation_loss, perplexity = await self.validation()
+                training_metreic = PostTrainingMetric(
+                    epoch=curr_epoch,
+                    train_loss=loss_to_log,
+                    validation_loss=validation_loss,
+                    perplexity=perplexity,
+                )
+                checkpoint.training_metric = training_metreic
             checkpoints.append(checkpoint)

         return (memory_stats, checkpoints)

-    async def validate(self) -> Tuple[float, float]:
+    async def validation(self) -> Tuple[float, float]:
         total_loss = 0.0
         total_tokens = 0
+        log.info("Starting validation...")
+        pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
             if idx == 10:
                 break
@@ -562,6 +576,9 @@ class LoraFinetuningSingleDevice:
             total_loss += loss
             total_tokens += num_tokens

+            pbar.update(1)
+            pbar.set_description(f"validation step: {idx}")
+
         mean_loss = total_loss / total_tokens
         perplexity = torch.exp(torch.tensor(mean_loss))
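The validation metric is token-level perplexity: the summed loss is divided by the total number of tokens seen, and perplexity is the exponential of that mean. For example, a total loss of 1200.0 over 1000 tokens gives a mean loss of 1.2 and a perplexity of exp(1.2) ≈ 3.32. A minimal sketch of the same computation (the running totals are made up):

    import torch

    total_loss, total_tokens = 1200.0, 1000  # illustrative running totals
    mean_loss = total_loss / total_tokens            # 1.2
    perplexity = torch.exp(torch.tensor(mean_loss))  # ≈ 3.32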
@@ -49,20 +49,5 @@ datasets:
       type: string
     text:
       type: string
-- dataset_id: alpaca_eval
-  provider_id: huggingface-0
-  url:
-    uri: https://huggingface.co/datasets/causal-lm/code_alpaca
-  metadata:
-    path: causal-lm/code_alpaca
-    name:
-    split: validation
-  dataset_schema:
-    instruction:
-      type: string
-    input:
-      type: string
-    output:
-      type: string
 scoring_fns: []
 eval_tasks: []