mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

commit 6c709abc4d (parent: bfc782c054)
commit message: temp commit

3 changed files with 28 additions and 54 deletions
@@ -180,16 +180,16 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
     def supervised_fine_tune(
         self,
-        job_uuid: Optional[str],
-        model: Optional[str],
-        dataset_id: Optional[str],
-        validation_dataset_id: Optional[str],
-        algorithm: Optional[FinetuningAlgorithm],
-        algorithm_config: Optional[LoraFinetuningConfig],
-        optimizer_config: Optional[OptimizerConfig],
-        training_config: Optional[TrainingConfig],
-        hyperparam_search_config: Optional[Dict[str, Any]],
-        logger_config: Optional[Dict[str, Any]],
+        job_uuid: str,
+        model: str,
+        dataset_id: str,
+        validation_dataset_id: str,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: LoraFinetuningConfig,
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
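With the Optional wrappers dropped from the protocol, every argument of supervised_fine_tune is now required at the call site. A minimal caller-side sketch, assuming a handle named post_training_impl that implements the PostTraining protocol (the handle name is illustrative; the job id, model, and dataset values mirror the defaults removed from the implementation in the next hunk, and the three config objects are built as in the sketch that follows that hunk):

job = post_training_impl.supervised_fine_tune(
    job_uuid="1234",
    model="meta-llama/Llama-3.2-3B-Instruct",
    dataset_id="alpaca",
    validation_dataset_id="alpaca",
    algorithm=FinetuningAlgorithm.lora,
    algorithm_config=lora_config,        # LoraFinetuningConfig, built by the caller
    optimizer_config=optimizer_config,   # OptimizerConfig, built by the caller
    training_config=training_config,     # TrainingConfig, built by the caller
    hyperparam_search_config={},         # now required; pass an empty dict if unused
    logger_config={},                    # now required; pass an empty dict if unused
)
# `job` is the PostTrainingJob describing the launched fine-tuning run.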
@@ -20,46 +20,18 @@ class MetaReferencePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api

-        LoraFinetuningConfig(
-            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
-            apply_lora_to_mlp=True,
-            apply_lora_to_output=False,
-            rank=8,
-            alpha=16,
-        )
-
-        OptimizerConfig(
-            optimizer_type=OptimizerType.adamw,
-            lr=3e-4,
-            lr_min=3e-5,
-            weight_decay=0.1,
-            num_warmup_steps=100,
-        )
-
-        TrainingConfig(
-            dtype="bf16",
-            n_epochs=1,
-            max_steps_per_epoch=10,
-            gradient_accumulation_steps=1,
-            batch_size=1,
-            shuffle=1,
-            enable_activation_checkpointing=False,
-            memory_efficient_fsdp_wrap=False,
-            fsdp_cpu_offload=False,
-        )
-
     def supervised_fine_tune(
         self,
-        job_uuid: Optional[str] = "1234",
-        model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct",
-        dataset_id: Optional[str] = "alpaca",
-        validation_dataset_id: Optional[str] = "alpaca",
-        algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora,
-        algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig,
-        optimizer_config: Optional[OptimizerConfig] = OptimizerConfig,
-        training_config: Optional[TrainingConfig] = TrainingConfig,
-        hyperparam_search_config: Optional[Dict[str, Any]] = {},
-        logger_config: Optional[Dict[str, Any]] = {},
+        job_uuid: str,
+        model: str,
+        dataset_id: str,
+        validation_dataset_id: str,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: LoraFinetuningConfig,
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = PostTrainingSFTRequest(
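Because the hard-coded LoraFinetuningConfig, OptimizerConfig, and TrainingConfig blocks above were deleted and the signature defaults removed, callers of this implementation now build the config objects themselves. A sketch that reuses the exact field values from the deleted lines (only the field names and values shown in the diff are taken as given; whether they remain sensible defaults is an assumption):

lora_config = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)

optimizer_config = OptimizerConfig(
    optimizer_type=OptimizerType.adamw,
    lr=3e-4,
    lr_min=3e-5,
    weight_decay=0.1,
    num_warmup_steps=100,
)

training_config = TrainingConfig(
    dtype="bf16",
    n_epochs=1,
    max_steps_per_epoch=10,
    gradient_accumulation_steps=1,
    batch_size=1,
    shuffle=1,  # value kept as it appeared in the deleted default
    enable_activation_checkpointing=False,
    memory_efficient_fsdp_wrap=False,
    fsdp_cpu_offload=False,
)

These are the objects passed as algorithm_config, optimizer_config, and training_config in the caller-side sketch shown after the protocol hunk above.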
@@ -71,6 +43,7 @@ class MetaReferencePostTrainingImpl:
             algorithm_config=algorithm_config,
             optimizer_config=optimizer_config,
             training_config=training_config,
+            hyperparam_search_config=hyperparam_search_config,
             logger_config=logger_config,
         )
         if request.algorithm == FinetuningAlgorithm.lora:
@@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
+from torchtune import utils as torchtune_utils
 from llama_stack.apis.post_training import * # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest

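The new import aliases torchtune's utils module, presumably to keep it distinct from the other utility namespaces already used in this file; the later hunks switch device selection and batch movement over to the alias. A small standalone sketch of those two calls as they appear in this diff (assumes torchtune and a CUDA-capable PyTorch build are installed; the tensors are placeholders):

import torch
from torchtune import utils as torchtune_utils

device = torchtune_utils.get_device(device="cuda")  # resolves "cuda" to a torch.device

batch = {
    "tokens": torch.zeros(2, 16, dtype=torch.long),
    "labels": torch.zeros(2, 16, dtype=torch.long),
}
torchtune_utils.batch_to_device(batch, device)  # moves the batch tensors onto the device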
@@ -56,14 +57,14 @@ class LoraFinetuningSingleDevice:
         # self._device = utils.get_device(device=cfg.device)
         self.config = config
         self.request = request
-        self._device = training.utils.get_device(device="cuda")
+        self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(
             request.training_config.dtype, device=self._device
         )
-        self.model_id = request.model
+        self.model_id = config.model

         # hardcode it for now and see how it works with get_training_job_artifacts
-        self._output_dir = f"~/.llama/checkpoints/post_training/{request.model_id}"
+        self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"

         self._log_every_n_steps = 1
         self._log_peak_memory_stats = False
@@ -111,8 +112,8 @@ class LoraFinetuningSingleDevice:
             return [f"Error: The directory '{checkpoint_dir}' does not exist."]

         self._checkpointer = training.FullModelMetaCheckpointer(
-            checkpoint_dir=self.config.checkpoint_dir,
-            checkpoint_files=get_checkpoint_files,
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
             # todo: automatically get this info from model
             model_type="LLAMA3",
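The checkpointer change above passes the resolved list of checkpoint files instead of the helper function object. The body of get_checkpoint_files is not part of this diff; a hypothetical sketch of what such a helper could look like, purely for illustration (the .pth filter and the sorting are assumptions, not the repository's actual implementation):

import os
from typing import List

def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
    # Hypothetical helper: list the checkpoint shards found in the directory,
    # roughly the shape FullModelMetaCheckpointer expects for checkpoint_files.
    expanded = os.path.expanduser(checkpoint_dir)
    return sorted(f for f in os.listdir(expanded) if f.endswith(".pth"))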
@@ -449,7 +450,7 @@ class LoraFinetuningSingleDevice:
                 # ):
                 # torch.cuda.memory._record_memory_history()

-                training.utils.batch_to_device(batch, self._device)
+                torchtune_utils.batch_to_device(batch, self._device)

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step