Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 00:34:44 +00:00)

commit 12eef58543
parent 2a15a8a005

    address comment

6 changed files with 58 additions and 87 deletions
@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum

-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod

@@ -62,13 +62,6 @@ class TrainingConfig(BaseModel):
     dtype: Optional[str] = "bf16"


-@json_schema_type
-class FinetuningAlgorithm(Enum):
-    full = "full"
-    lora = "lora"
-    qat = "qat"
-
-
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     lora_attn_modules: List[str]
@@ -172,12 +165,17 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        algorithm_config: Optional[LoraFinetuningConfig] = None,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        checkpoint_dir: Optional[str] = None,
+        algorithm_config: Optional[
+            Union[LoraFinetuningConfig, QATFinetuningConfig]
+        ] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
@@ -6,15 +6,8 @@

 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 class TorchtunePostTrainingConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
     torch_seed: Optional[int] = None
-    # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
-    # can override by specifying the directory explicitly
-    checkpoint_dir: Optional[str] = None
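Context note: with model and checkpoint_dir moved onto the supervised_fine_tune call, only torch_seed is left on the provider config, which is why the empty config: {} block in the yaml hunk at the end of this diff validates. A minimal, self-contained sketch of that trimmed config (mirroring the class in the hunk above, not importing the real module):

from typing import Optional

from pydantic import BaseModel


class TorchtunePostTrainingConfig(BaseModel):
    # Only knob left on the provider config after this change.
    torch_seed: Optional[int] = None


print(TorchtunePostTrainingConfig())               # torch_seed=None, i.e. `config: {}`
print(TorchtunePostTrainingConfig(torch_seed=42))  # explicit seed still accepted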
@@ -13,18 +13,6 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
 )


-class PostTrainingSFTRequest(BaseModel):
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class TorchtunePostTrainingImpl:
     def __init__(
         self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
@@ -35,29 +23,25 @@ class TorchtunePostTrainingImpl:
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
-
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
-        request = PostTrainingSFTRequest(
-            job_uuid=job_uuid,
-            model=model,
-            algorithm=algorithm,
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            hyperparam_search_config=hyperparam_search_config,
-            logger_config=logger_config,
-        )
-        if request.algorithm == FinetuningAlgorithm.lora:
+        if isinstance(algorithm_config, LoraFinetuningConfig):
             recipe = LoraFinetuningSingleDevice(
-                self.config, request, self.datasetio_api
+                self.config,
+                training_config,
+                hyperparam_search_config,
+                logger_config,
+                model,
+                checkpoint_dir,
+                algorithm_config,
+                self.datasetio_api,
             )
-            await recipe.setup(self.config)
+            await recipe.setup()
             await recipe.train()
         else:
             raise NotImplementedError()
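The hunk above drops the FinetuningAlgorithm enum check and dispatches on the concrete type of algorithm_config instead. A self-contained sketch of that pattern, with stand-in dataclasses rather than the real LoraFinetuningConfig / QATFinetuningConfig models:

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class LoraCfg:  # stand-in for LoraFinetuningConfig
    rank: int
    alpha: int


@dataclass
class QATCfg:  # stand-in for QATFinetuningConfig
    quantizer_name: str


def pick_recipe(algorithm_config: Optional[Union[LoraCfg, QATCfg]]) -> str:
    # Mirrors the new control flow: only the LoRA path is wired up,
    # anything else falls through to NotImplementedError.
    if isinstance(algorithm_config, LoraCfg):
        return "LoraFinetuningSingleDevice"
    raise NotImplementedError("only LoRA fine-tuning is handled in this commit")


print(pick_recipe(LoraCfg(rank=8, alpha=16)))  # -> LoraFinetuningSingleDevice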
@@ -22,12 +22,9 @@ from llama_stack.distribution.utils.model_utils import model_local_dir

 from llama_stack.providers.inline.post_training.torchtune import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
-    MetaReferencePostTrainingConfig,
+    TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from llama_stack.providers.inline.post_training.torchtune.post_training import (
-    PostTrainingSFTRequest,
-)
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
@@ -64,18 +61,21 @@ class LoraFinetuningSingleDevice:
     # and make it work with telemetry
     def __init__(
         self,
-        config: MetaReferencePostTrainingConfig,
-        request: PostTrainingSFTRequest,
+        config: TorchtunePostTrainingConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
         # Assume the training only happens on GPU
         self.config = config
-        self.request = request
+        self.training_config = training_config
+        self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
-        self._dtype = training.get_dtype(
-            request.training_config.dtype, device=self._device
-        )
-        self.model_id = config.model
+        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
+        self.model_id = model

         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -93,7 +93,7 @@ class LoraFinetuningSingleDevice:
             )
             return str(checkpoint_dir)

-        if config.checkpoint_dir and config.checkpoint_dir != "null":
+        if checkpoint_dir and checkpoint_dir != "null":
             self.checkpoint_dir = config.checkpoint_dir
         else:
             model = resolve_model(self.model_id)
@@ -102,29 +102,27 @@ class LoraFinetuningSingleDevice:
         # TODO @markchen1015 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"

-        self.seed = training.set_seed(seed=config.torch_seed or 42)
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
-        self.total_epochs = request.training_config.n_epochs
-        self._shuffle = request.training_config.data_config.shuffle
-        self._batch_size = request.training_config.data_config.batch_size
+        self.total_epochs = training_config.n_epochs
+        self._shuffle = training_config.data_config.shuffle
+        self._batch_size = training_config.data_config.batch_size

         # this is important for debugging purpose
-        self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
+        self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0

-        self._gradient_accumulation_steps = (
-            request.training_config.gradient_accumulation_steps
-        )
+        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps

         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
-            (request.training_config.efficiency_config.enable_activation_checkpointing)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_checkpointing)
+            if training_config.efficiency_config
             else False
         )
         self._enable_activation_offloading = (
-            (request.training_config.efficiency_config.enable_activation_offloading)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_offloading)
+            if training_config.efficiency_config
             else False
         )
@@ -150,7 +148,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict

-    async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+    async def setup(self) -> None:
         # temporily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)

@@ -168,7 +166,7 @@ class LoraFinetuningSingleDevice:
         log.info("Tokenizer is initialized from file.")

         self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer_config
+            optimizer_config=self.training_config.optimizer_config
         )
         log.info("Optimizer is initialized.")

@@ -200,7 +198,7 @@ class LoraFinetuningSingleDevice:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = await self._setup_lr_scheduler(
-            num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
+            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -218,12 +216,12 @@ class LoraFinetuningSingleDevice:
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
-        self._lora_rank = self.request.algorithm_config.rank
-        self._lora_alpha = self.request.algorithm_config.alpha
-        self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
-        self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
-        self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
-        self._use_dora = self.request.algorithm_config.use_dora or False
+        self._lora_rank = self.algorithm_config.rank
+        self._lora_alpha = self.algorithm_config.alpha
+        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
+        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
+        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
+        self._use_dora = self.algorithm_config.use_dora or False

         with training.set_default_dtype(self._dtype), self._device:
             model_type = utils.get_model_type(self.model_id)
@@ -311,7 +309,7 @@ class LoraFinetuningSingleDevice:
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.request.training_config.data_config.dataset_id,
+                dataset_id=self.training_config.data_config.dataset_id,
                 rows_in_page=-1,
             )

@@ -16,7 +16,7 @@ def available_providers() -> List[ProviderSpec]:
             provider_type="inline::torchtune",
             pip_packages=["torch", "torchtune", "torchao", "numpy"],
             module="llama_stack.providers.inline.post_training.torchtune",
-            config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
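The one-character fix above matters because the provider's config_class is resolved from its dotted string at runtime, and attribute lookup is case-sensitive. A generic sketch of that kind of resolution (an assumption about the mechanism, not the actual llama-stack loader):

import importlib


def resolve_class(dotted_path: str):
    # Split "pkg.module.ClassName" into module path and attribute name,
    # import the module, then look the class up by name.
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


# The lowercase "torchtunePostTrainingConfig" string would raise AttributeError here;
# the corrected "TorchtunePostTrainingConfig" resolves as expected.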
@@ -49,9 +49,7 @@ providers:
   post_training:
   - provider_id: meta-reference-post-training
     provider_type: inline::torchtune
-    config:
-      model: ${env.POST_TRAINING_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+    config: {}

 metadata_store:
   namespace: null