temp commit

Botao Chen 2024-11-27 16:46:29 -08:00
parent bfc782c054
commit 6c709abc4d
3 changed files with 28 additions and 54 deletions

File 1 of 3: the PostTraining API protocol

@@ -180,16 +180,16 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
     def supervised_fine_tune(
         self,
-        job_uuid: Optional[str],
-        model: Optional[str],
-        dataset_id: Optional[str],
-        validation_dataset_id: Optional[str],
-        algorithm: Optional[FinetuningAlgorithm],
-        algorithm_config: Optional[LoraFinetuningConfig],
-        optimizer_config: Optional[OptimizerConfig],
-        training_config: Optional[TrainingConfig],
-        hyperparam_search_config: Optional[Dict[str, Any]],
-        logger_config: Optional[Dict[str, Any]],
+        job_uuid: str,
+        model: str,
+        dataset_id: str,
+        validation_dataset_id: str,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: LoraFinetuningConfig,
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
     @webmethod(route="/post-training/preference-optimize")

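With every argument on this endpoint now required rather than Optional, a caller has to spell out the whole training setup explicitly. Below is a minimal sketch of such a call; `post_training` is a hypothetical object implementing the PostTraining protocol, the import path mirrors the wildcard import in the recipe file further down, and the config values are the ones that appear elsewhere in this commit.

from llama_stack.apis.post_training import (
    FinetuningAlgorithm,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)

# Returns a PostTrainingJob handle for the queued fine-tuning run.
job = post_training.supervised_fine_tune(
    job_uuid="1234",
    model="meta-llama/Llama-3.2-3B-Instruct",
    dataset_id="alpaca",
    validation_dataset_id="alpaca",
    algorithm=FinetuningAlgorithm.lora,
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    ),
    optimizer_config=OptimizerConfig(
        optimizer_type=OptimizerType.adamw,
        lr=3e-4,
        lr_min=3e-5,
        weight_decay=0.1,
        num_warmup_steps=100,
    ),
    training_config=TrainingConfig(
        dtype="bf16",
        n_epochs=1,
        max_steps_per_epoch=10,
        gradient_accumulation_steps=1,
        batch_size=1,
        shuffle=True,
        enable_activation_checkpointing=False,
        memory_efficient_fsdp_wrap=False,
        fsdp_cpu_offload=False,
    ),
    hyperparam_search_config={},
    logger_config={},
)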
File 2 of 3: the meta-reference provider implementation (MetaReferencePostTrainingImpl)

@@ -20,46 +20,18 @@ class MetaReferencePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api
-        LoraFinetuningConfig(
-            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
-            apply_lora_to_mlp=True,
-            apply_lora_to_output=False,
-            rank=8,
-            alpha=16,
-        )
-        OptimizerConfig(
-            optimizer_type=OptimizerType.adamw,
-            lr=3e-4,
-            lr_min=3e-5,
-            weight_decay=0.1,
-            num_warmup_steps=100,
-        )
-        TrainingConfig(
-            dtype="bf16",
-            n_epochs=1,
-            max_steps_per_epoch=10,
-            gradient_accumulation_steps=1,
-            batch_size=1,
-            shuffle=1,
-            enable_activation_checkpointing=False,
-            memory_efficient_fsdp_wrap=False,
-            fsdp_cpu_offload=False,
-        )
     def supervised_fine_tune(
         self,
-        job_uuid: Optional[str] = "1234",
-        model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct",
-        dataset_id: Optional[str] = "alpaca",
-        validation_dataset_id: Optional[str] = "alpaca",
-        algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora,
-        algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig,
-        optimizer_config: Optional[OptimizerConfig] = OptimizerConfig,
-        training_config: Optional[TrainingConfig] = TrainingConfig,
-        hyperparam_search_config: Optional[Dict[str, Any]] = {},
-        logger_config: Optional[Dict[str, Any]] = {},
+        job_uuid: str,
+        model: str,
+        dataset_id: str,
+        validation_dataset_id: str,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: LoraFinetuningConfig,
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = PostTrainingSFTRequest(
@@ -71,6 +43,7 @@ class MetaReferencePostTrainingImpl:
             algorithm_config=algorithm_config,
             optimizer_config=optimizer_config,
             training_config=training_config,
+            hyperparam_search_config=hyperparam_search_config,
             logger_config=logger_config,
         )
         if request.algorithm == FinetuningAlgorithm.lora:

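Two things fall out of this file's changes: the hardcoded LoraFinetuningConfig/OptimizerConfig/TrainingConfig instances are gone, and the old defaults that bound the config classes themselves (e.g. algorithm_config=LoraFinetuningConfig rather than an instance) can no longer mask a missing argument. The method now simply wraps whatever the caller passed. A sketch of that wrapper construction, assuming PostTrainingSFTRequest carries one field per API argument (only the fields from algorithm_config onward are visible in the hunk above):

# Inside supervised_fine_tune(); internal only, mirroring the hunk above.
request = PostTrainingSFTRequest(
    job_uuid=job_uuid,
    model=model,
    dataset_id=dataset_id,
    validation_dataset_id=validation_dataset_id,
    algorithm=algorithm,
    algorithm_config=algorithm_config,
    optimizer_config=optimizer_config,
    training_config=training_config,
    hyperparam_search_config=hyperparam_search_config,  # newly forwarded by this commit
    logger_config=logger_config,
)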
File 3 of 3: the torchtune single-device LoRA recipe (LoraFinetuningSingleDevice)

@@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
+from torchtune import utils as torchtune_utils
 from llama_stack.apis.post_training import * # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest
@@ -56,14 +57,14 @@ class LoraFinetuningSingleDevice:
         # self._device = utils.get_device(device=cfg.device)
         self.config = config
         self.request = request
-        self._device = training.utils.get_device(device="cuda")
+        self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(
             request.training_config.dtype, device=self._device
         )
-        self.model_id = request.model
+        self.model_id = config.model
         # hardcode it for now and see how it works with get_training_job_artifacts
-        self._output_dir = f"~/.llama/checkpoints/post_training/{request.model_id}"
+        self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"
         self._log_every_n_steps = 1
         self._log_peak_memory_stats = False
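The recipe now resolves its device through the torchtune_utils alias added in the import hunk above, and keeps using training.get_dtype for the dtype string carried by TrainingConfig. A small standalone sketch of those two helpers, assuming the pinned torchtune version exposes them with these signatures:

import torch
from torchtune import training, utils as torchtune_utils

# Resolve the string "cuda" to a concrete torch.device (e.g. cuda:0).
device = torchtune_utils.get_device(device="cuda")

# Map the TrainingConfig dtype string ("bf16" in this commit) to a torch.dtype,
# validating that the chosen device supports it.
dtype = training.get_dtype("bf16", device=device)

print(device, dtype)  # e.g. cuda:0 torch.bfloat16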
@@ -111,8 +112,8 @@ class LoraFinetuningSingleDevice:
             return [f"Error: The directory '{checkpoint_dir}' does not exist."]
         self._checkpointer = training.FullModelMetaCheckpointer(
-            checkpoint_dir=self.config.checkpoint_dir,
-            checkpoint_files=get_checkpoint_files,
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
             # todo: automatically get this info from model
             model_type="LLAMA3",
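The checkpointer now reads from self.checkpoint_dir and receives an actual list of checkpoint files instead of the bare get_checkpoint_files function object. The helper's body is not shown in this diff, so the version below is purely an assumption, as are the example paths; it only illustrates how the pieces feed torchtune's Meta-format checkpointer.

import os
from torchtune import training

def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
    # Assumption: the Meta-format checkpoint shards are the .pth files in the directory.
    return [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]

checkpoint_dir = os.path.expanduser("~/.llama/checkpoints/Llama3.2-3B-Instruct")  # hypothetical path
checkpointer = training.FullModelMetaCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=get_checkpoint_files(checkpoint_dir),
    output_dir=os.path.expanduser("~/.llama/checkpoints/post_training/Llama3.2-3B-Instruct"),
    model_type="LLAMA3",  # value taken from the hunk above
)
state_dict = checkpointer.load_checkpoint()  # converted state dict; model weights live under the "model" key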
@@ -449,7 +450,7 @@ class LoraFinetuningSingleDevice:
             # ):
             # torch.cuda.memory._record_memory_history()
-            training.utils.batch_to_device(batch, self._device)
+            torchtune_utils.batch_to_device(batch, self._device)
             # Calculate the number of unmasked tokens in the current batch
             # and increment the total number of tokens seen in the step
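Finally, the train step moves each batch with the same alias. A tiny standalone sketch of batch_to_device; the real batch comes from the recipe's dataloader, so the key names and shapes here are fabricated:

import torch
from torchtune import utils as torchtune_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "tokens": torch.randint(0, 100, (1, 128)),  # fabricated token ids
    "labels": torch.randint(0, 100, (1, 128)),
}
# Moves every tensor in the (possibly nested) dict onto `device`, in place.
torchtune_utils.batch_to_device(batch, device)
print(batch["tokens"].device)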