temp commit

Botao Chen 2024-11-27 16:46:29 -08:00
parent bfc782c054
commit 6c709abc4d
3 changed files with 28 additions and 54 deletions

View file

@@ -180,16 +180,16 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
     def supervised_fine_tune(
         self,
-        job_uuid: Optional[str],
-        model: Optional[str],
-        dataset_id: Optional[str],
-        validation_dataset_id: Optional[str],
-        algorithm: Optional[FinetuningAlgorithm],
-        algorithm_config: Optional[LoraFinetuningConfig],
-        optimizer_config: Optional[OptimizerConfig],
-        training_config: Optional[TrainingConfig],
-        hyperparam_search_config: Optional[Dict[str, Any]],
-        logger_config: Optional[Dict[str, Any]],
+        job_uuid: str,
+        model: str,
+        dataset_id: str,
+        validation_dataset_id: str,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: LoraFinetuningConfig,
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
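
With the Optional[...] annotations dropped from the protocol, every argument is now required at the call site. Below is a minimal, illustrative sketch of a caller against the updated signature, reusing the config values that the reference implementation previously hardcoded; the post_training handle and the call style are assumptions for illustration, not part of this commit.

# Illustrative only: assumes a `post_training` object implementing the
# PostTraining protocol; config values mirror the defaults removed in this commit.
from llama_stack.apis.post_training import *  # noqa

job = post_training.supervised_fine_tune(
    job_uuid="1234",
    model="meta-llama/Llama-3.2-3B-Instruct",
    dataset_id="alpaca",
    validation_dataset_id="alpaca",
    algorithm=FinetuningAlgorithm.lora,
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    ),
    optimizer_config=OptimizerConfig(
        optimizer_type=OptimizerType.adamw,
        lr=3e-4,
        lr_min=3e-5,
        weight_decay=0.1,
        num_warmup_steps=100,
    ),
    training_config=TrainingConfig(
        dtype="bf16",
        n_epochs=1,
        max_steps_per_epoch=10,
        gradient_accumulation_steps=1,
        batch_size=1,
        shuffle=1,
        enable_activation_checkpointing=False,
        memory_efficient_fsdp_wrap=False,
        fsdp_cpu_offload=False,
    ),
    hyperparam_search_config={},
    logger_config={},
)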

View file

@@ -20,46 +20,18 @@ class MetaReferencePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api

-        LoraFinetuningConfig(
-            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
-            apply_lora_to_mlp=True,
-            apply_lora_to_output=False,
-            rank=8,
-            alpha=16,
-        )
-
-        OptimizerConfig(
-            optimizer_type=OptimizerType.adamw,
-            lr=3e-4,
-            lr_min=3e-5,
-            weight_decay=0.1,
-            num_warmup_steps=100,
-        )
-
-        TrainingConfig(
-            dtype="bf16",
-            n_epochs=1,
-            max_steps_per_epoch=10,
-            gradient_accumulation_steps=1,
-            batch_size=1,
-            shuffle=1,
-            enable_activation_checkpointing=False,
-            memory_efficient_fsdp_wrap=False,
-            fsdp_cpu_offload=False,
-        )
-
     def supervised_fine_tune(
         self,
-        job_uuid: Optional[str] = "1234",
-        model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct",
-        dataset_id: Optional[str] = "alpaca",
-        validation_dataset_id: Optional[str] = "alpaca",
-        algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora,
-        algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig,
-        optimizer_config: Optional[OptimizerConfig] = OptimizerConfig,
-        training_config: Optional[TrainingConfig] = TrainingConfig,
-        hyperparam_search_config: Optional[Dict[str, Any]] = {},
-        logger_config: Optional[Dict[str, Any]] = {},
+        job_uuid: str,
+        model: str,
+        dataset_id: str,
+        validation_dataset_id: str,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: LoraFinetuningConfig,
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = PostTrainingSFTRequest(
@@ -71,6 +43,7 @@ class MetaReferencePostTrainingImpl:
             algorithm_config=algorithm_config,
            optimizer_config=optimizer_config,
             training_config=training_config,
+            hyperparam_search_config=hyperparam_search_config,
             logger_config=logger_config,
         )
         if request.algorithm == FinetuningAlgorithm.lora:

View file

@@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
+from torchtune import utils as torchtune_utils

 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest
@@ -56,14 +57,14 @@ class LoraFinetuningSingleDevice:
         # self._device = utils.get_device(device=cfg.device)
         self.config = config
         self.request = request
-        self._device = training.utils.get_device(device="cuda")
+        self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(
             request.training_config.dtype, device=self._device
         )
-        self.model_id = request.model
+        self.model_id = config.model

         # hardcode it for now and see how it works with get_training_job_artifacts
-        self._output_dir = f"~/.llama/checkpoints/post_training/{request.model_id}"
+        self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"

         self._log_every_n_steps = 1
         self._log_peak_memory_stats = False
@@ -111,8 +112,8 @@ class LoraFinetuningSingleDevice:
             return [f"Error: The directory '{checkpoint_dir}' does not exist."]

         self._checkpointer = training.FullModelMetaCheckpointer(
-            checkpoint_dir=self.config.checkpoint_dir,
-            checkpoint_files=get_checkpoint_files,
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
             # todo: automatically get this info from model
             model_type="LLAMA3",
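
The checkpointer now receives the result of get_checkpoint_files(self.checkpoint_dir) instead of the bare function object. The repo's actual helper is not shown in this diff; the sketch below is a hypothetical shape for such a helper, assuming torchtune's FullModelMetaCheckpointer expects the names of the Meta-format checkpoint files found under checkpoint_dir.

from pathlib import Path
from typing import List

# Hypothetical helper sketch (not the implementation from this repo): collect the
# Meta-format checkpoint shards (e.g. consolidated.00.pth) inside checkpoint_dir
# and return just the file names for FullModelMetaCheckpointer.
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
    ckpt_dir = Path(checkpoint_dir).expanduser()
    return sorted(p.name for p in ckpt_dir.glob("*.pth"))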
@@ -449,7 +450,7 @@ class LoraFinetuningSingleDevice:
                 # ):
                 #     torch.cuda.memory._record_memory_history()

-                training.utils.batch_to_device(batch, self._device)
+                torchtune_utils.batch_to_device(batch, self._device)

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
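
The recipe now resolves its device utilities through torchtune.utils (aliased as torchtune_utils) rather than training.utils. A small standalone sketch of that setup pattern, assuming a CUDA-capable machine and the torchtune version this recipe targets:

import torch
from torchtune import training
from torchtune import utils as torchtune_utils

# Device/dtype resolution as used in the recipe (assumes a CUDA GPU is available).
device = torchtune_utils.get_device(device="cuda")
dtype = training.get_dtype("bf16", device=device)

# batch_to_device moves every tensor in the collated batch dict onto `device`.
batch = {
    "tokens": torch.randint(0, 32_000, (1, 16)),
    "labels": torch.randint(0, 32_000, (1, 16)),
}
torchtune_utils.batch_to_device(batch, device)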