mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
temp commit
This commit is contained in:
parent
bfc782c054
commit
6c709abc4d
3 changed files with 28 additions and 54 deletions
|
@ -180,16 +180,16 @@ class PostTraining(Protocol):
|
|||
@webmethod(route="/post-training/supervised-fine-tune")
|
||||
def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: Optional[str],
|
||||
model: Optional[str],
|
||||
dataset_id: Optional[str],
|
||||
validation_dataset_id: Optional[str],
|
||||
algorithm: Optional[FinetuningAlgorithm],
|
||||
algorithm_config: Optional[LoraFinetuningConfig],
|
||||
optimizer_config: Optional[OptimizerConfig],
|
||||
training_config: Optional[TrainingConfig],
|
||||
hyperparam_search_config: Optional[Dict[str, Any]],
|
||||
logger_config: Optional[Dict[str, Any]],
|
||||
job_uuid: str,
|
||||
model: str,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: FinetuningAlgorithm,
|
||||
algorithm_config: LoraFinetuningConfig,
|
||||
optimizer_config: OptimizerConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize")
|
||||
|
|
|
@ -20,46 +20,18 @@ class MetaReferencePostTrainingImpl:
|
|||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
|
||||
LoraFinetuningConfig(
|
||||
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=False,
|
||||
rank=8,
|
||||
alpha=16,
|
||||
)
|
||||
|
||||
OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adamw,
|
||||
lr=3e-4,
|
||||
lr_min=3e-5,
|
||||
weight_decay=0.1,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
TrainingConfig(
|
||||
dtype="bf16",
|
||||
n_epochs=1,
|
||||
max_steps_per_epoch=10,
|
||||
gradient_accumulation_steps=1,
|
||||
batch_size=1,
|
||||
shuffle=1,
|
||||
enable_activation_checkpointing=False,
|
||||
memory_efficient_fsdp_wrap=False,
|
||||
fsdp_cpu_offload=False,
|
||||
)
|
||||
|
||||
def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: Optional[str] = "1234",
|
||||
model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct",
|
||||
dataset_id: Optional[str] = "alpaca",
|
||||
validation_dataset_id: Optional[str] = "alpaca",
|
||||
algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora,
|
||||
algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig,
|
||||
optimizer_config: Optional[OptimizerConfig] = OptimizerConfig,
|
||||
training_config: Optional[TrainingConfig] = TrainingConfig,
|
||||
hyperparam_search_config: Optional[Dict[str, Any]] = {},
|
||||
logger_config: Optional[Dict[str, Any]] = {},
|
||||
job_uuid: str,
|
||||
model: str,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: FinetuningAlgorithm,
|
||||
algorithm_config: LoraFinetuningConfig,
|
||||
optimizer_config: OptimizerConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = PostTrainingSFTRequest(
|
||||
|
@ -71,6 +43,7 @@ class MetaReferencePostTrainingImpl:
|
|||
algorithm_config=algorithm_config,
|
||||
optimizer_config=optimizer_config,
|
||||
training_config=training_config,
|
||||
hyperparam_search_config=hyperparam_search_config,
|
||||
logger_config=logger_config,
|
||||
)
|
||||
if request.algorithm == FinetuningAlgorithm.lora:
|
||||
|
|
|
@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
import torch
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from torch import nn
|
||||
from torchtune import utils as torchtune_utils
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
from llama_stack.apis.post_training import PostTrainingSFTRequest
|
||||
|
||||
|
@ -56,14 +57,14 @@ class LoraFinetuningSingleDevice:
|
|||
# self._device = utils.get_device(device=cfg.device)
|
||||
self.config = config
|
||||
self.request = request
|
||||
self._device = training.utils.get_device(device="cuda")
|
||||
self._device = torchtune_utils.get_device(device="cuda")
|
||||
self._dtype = training.get_dtype(
|
||||
request.training_config.dtype, device=self._device
|
||||
)
|
||||
self.model_id = request.model
|
||||
self.model_id = config.model
|
||||
|
||||
# hardcode it for now and see how it works with get_training_job_artifacts
|
||||
self._output_dir = f"~/.llama/checkpoints/post_training/{request.model_id}"
|
||||
self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"
|
||||
|
||||
self._log_every_n_steps = 1
|
||||
self._log_peak_memory_stats = False
|
||||
|
@ -111,8 +112,8 @@ class LoraFinetuningSingleDevice:
|
|||
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
|
||||
|
||||
self._checkpointer = training.FullModelMetaCheckpointer(
|
||||
checkpoint_dir=self.config.checkpoint_dir,
|
||||
checkpoint_files=get_checkpoint_files,
|
||||
checkpoint_dir=self.checkpoint_dir,
|
||||
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
|
||||
output_dir=self._output_dir,
|
||||
# todo: automatically get this info from model
|
||||
model_type="LLAMA3",
|
||||
|
@ -449,7 +450,7 @@ class LoraFinetuningSingleDevice:
|
|||
# ):
|
||||
# torch.cuda.memory._record_memory_history()
|
||||
|
||||
training.utils.batch_to_device(batch, self._device)
|
||||
torchtune_utils.batch_to_device(batch, self._device)
|
||||
|
||||
# Calculate the number of unmasked tokens in the current batch
|
||||
# and increment the total number of tokens seen in the step
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue