Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-02 16:54:42 +00:00

commit 12eef58543 ("address comment")
parent 2a15a8a005

6 changed files with 58 additions and 87 deletions
@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum

-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
@@ -62,13 +62,6 @@ class TrainingConfig(BaseModel):
     dtype: Optional[str] = "bf16"


-@json_schema_type
-class FinetuningAlgorithm(Enum):
-    full = "full"
-    lora = "lora"
-    qat = "qat"
-
-
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     lora_attn_modules: List[str]
@@ -172,12 +165,17 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        algorithm_config: Optional[LoraFinetuningConfig] = None,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        checkpoint_dir: Optional[str] = None,
+        algorithm_config: Optional[
+            Union[LoraFinetuningConfig, QATFinetuningConfig]
+        ] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
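For orientation, here is a minimal sketch of how a caller might invoke the reworked method. The object `post_training_impl` and the pre-built `training_config` are assumptions for illustration; the keyword arguments and the `LoraFinetuningConfig` fields are the ones referenced elsewhere in this diff, and the values shown are placeholders rather than recommended settings.

    # Illustrative only: `post_training_impl` is any implementation of the
    # PostTraining protocol above; building `training_config` is omitted here.
    job = await post_training_impl.supervised_fine_tune(
        job_uuid="sft-job-1234",
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        model="Llama3.2-3B-Instruct",
        checkpoint_dir=None,  # fall back to the default checkpoint location
        algorithm_config=LoraFinetuningConfig(
            lora_attn_modules=["q_proj", "v_proj"],  # placeholder module list
            apply_lora_to_mlp=False,
            apply_lora_to_output=False,
            rank=8,
            alpha=16,
        ),
    )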
@@ -6,15 +6,8 @@

 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 class TorchtunePostTrainingConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
     torch_seed: Optional[int] = None
-    # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
-    # can override by specifying the directory explicitly
-    checkpoint_dir: Optional[str] = None
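With `model` and `checkpoint_dir` moved onto the fine-tune call itself, the provider config above is reduced to an optional seed. A quick sketch of what instantiating it now looks like, assuming the class exactly as shown in this hunk:

    from typing import Optional

    from pydantic import BaseModel


    class TorchtunePostTrainingConfig(BaseModel):
        torch_seed: Optional[int] = None


    # The provider no longer carries per-model settings, so an empty config is valid.
    default_config = TorchtunePostTrainingConfig()
    seeded_config = TorchtunePostTrainingConfig(torch_seed=42)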
@@ -13,18 +13,6 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
 )


-class PostTrainingSFTRequest(BaseModel):
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class TorchtunePostTrainingImpl:
     def __init__(
         self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
@@ -35,29 +23,25 @@ class TorchtunePostTrainingImpl:
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
-        request = PostTrainingSFTRequest(
-            job_uuid=job_uuid,
-            model=model,
-            algorithm=algorithm,
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            hyperparam_search_config=hyperparam_search_config,
-            logger_config=logger_config,
-        )
-        if request.algorithm == FinetuningAlgorithm.lora:
+        if isinstance(algorithm_config, LoraFinetuningConfig):
             recipe = LoraFinetuningSingleDevice(
-                self.config, request, self.datasetio_api
+                self.config,
+                training_config,
+                hyperparam_search_config,
+                logger_config,
+                model,
+                checkpoint_dir,
+                algorithm_config,
+                self.datasetio_api,
             )
-            await recipe.setup(self.config)
+            await recipe.setup()
             await recipe.train()
         else:
             raise NotImplementedError()
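The change above drops the `FinetuningAlgorithm` enum and the internal `PostTrainingSFTRequest` wrapper; the algorithm is now chosen by dispatching on the type of `algorithm_config`. A small self-contained sketch of that pattern, with illustrative names that are not from the repo:

    from dataclasses import dataclass
    from typing import Optional, Union


    @dataclass
    class LoraCfg:
        rank: int = 8
        alpha: int = 16


    @dataclass
    class QATCfg:
        quantizer_name: str = "int4"


    AlgorithmCfg = Union[LoraCfg, QATCfg]


    def select_recipe(algorithm_config: Optional[AlgorithmCfg]) -> str:
        # The config object's class carries the algorithm choice, so no enum is needed.
        if isinstance(algorithm_config, LoraCfg):
            return "lora_finetuning_single_device"
        raise NotImplementedError(f"unsupported algorithm config: {algorithm_config!r}")


    assert select_recipe(LoraCfg()) == "lora_finetuning_single_device"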
@@ -22,12 +22,9 @@ from llama_stack.distribution.utils.model_utils import model_local_dir

 from llama_stack.providers.inline.post_training.torchtune import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
-    MetaReferencePostTrainingConfig,
+    TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from llama_stack.providers.inline.post_training.torchtune.post_training import (
-    PostTrainingSFTRequest,
-)
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
@@ -64,18 +61,21 @@ class LoraFinetuningSingleDevice:
     # and make it work with telemetry
     def __init__(
         self,
-        config: MetaReferencePostTrainingConfig,
-        request: PostTrainingSFTRequest,
+        config: TorchtunePostTrainingConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
         # Assume the training only happens on GPU
-        self.config = config
-        self.request = request
+        self.training_config = training_config
+        self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
-        self._dtype = training.get_dtype(
-            request.training_config.dtype, device=self._device
-        )
-        self.model_id = config.model
+        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
+        self.model_id = model

         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -93,7 +93,7 @@ class LoraFinetuningSingleDevice:
             )
             return str(checkpoint_dir)

-        if config.checkpoint_dir and config.checkpoint_dir != "null":
+        if checkpoint_dir and checkpoint_dir != "null":
             self.checkpoint_dir = config.checkpoint_dir
         else:
             model = resolve_model(self.model_id)
@@ -102,29 +102,27 @@ class LoraFinetuningSingleDevice:
         # TODO @markchen1015 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"

-        self.seed = training.set_seed(seed=config.torch_seed or 42)
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
-        self.total_epochs = request.training_config.n_epochs
-        self._shuffle = request.training_config.data_config.shuffle
-        self._batch_size = request.training_config.data_config.batch_size
+        self.total_epochs = training_config.n_epochs
+        self._shuffle = training_config.data_config.shuffle
+        self._batch_size = training_config.data_config.batch_size

         # this is important for debugging purpose
-        self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
+        self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0

-        self._gradient_accumulation_steps = (
-            request.training_config.gradient_accumulation_steps
-        )
+        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps

         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
-            (request.training_config.efficiency_config.enable_activation_checkpointing)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_checkpointing)
+            if training_config.efficiency_config
             else False
         )
         self._enable_activation_offloading = (
-            (request.training_config.efficiency_config.enable_activation_offloading)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_offloading)
+            if training_config.efficiency_config
             else False
         )

@@ -150,7 +148,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict

-    async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+    async def setup(self) -> None:
         # temporily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)

@@ -168,7 +166,7 @@ class LoraFinetuningSingleDevice:
         log.info("Tokenizer is initialized from file.")

         self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer_config
+            optimizer_config=self.training_config.optimizer_config
         )
         log.info("Optimizer is initialized.")

@@ -200,7 +198,7 @@ class LoraFinetuningSingleDevice:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = await self._setup_lr_scheduler(
-            num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
+            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -218,12 +216,12 @@ class LoraFinetuningSingleDevice:
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
-        self._lora_rank = self.request.algorithm_config.rank
-        self._lora_alpha = self.request.algorithm_config.alpha
-        self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
-        self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
-        self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
-        self._use_dora = self.request.algorithm_config.use_dora or False
+        self._lora_rank = self.algorithm_config.rank
+        self._lora_alpha = self.algorithm_config.alpha
+        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
+        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
+        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
+        self._use_dora = self.algorithm_config.use_dora or False

         with training.set_default_dtype(self._dtype), self._device:
             model_type = utils.get_model_type(self.model_id)
@@ -311,7 +309,7 @@ class LoraFinetuningSingleDevice:
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.request.training_config.data_config.dataset_id,
+                dataset_id=self.training_config.data_config.dataset_id,
                 rows_in_page=-1,
             )
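Taken together with the `supervised_fine_tune` hunk earlier, the recipe is now driven roughly like this (a sketch only; every variable is assumed to already be in scope in the caller):

    # Mirrors the updated TorchtunePostTrainingImpl.supervised_fine_tune body.
    recipe = LoraFinetuningSingleDevice(
        config,                    # TorchtunePostTrainingConfig (now just the seed)
        training_config,           # TrainingConfig
        hyperparam_search_config,  # Dict[str, Any]
        logger_config,             # Dict[str, Any]
        model,                     # model descriptor, e.g. "Llama3.2-3B-Instruct"
        checkpoint_dir,            # Optional[str]; None means the default location
        algorithm_config,          # LoraFinetuningConfig for the LoRA path
        datasetio_api,             # DatasetIO used to fetch the training rows
    )
    await recipe.setup()   # setup() no longer takes the provider config
    await recipe.train()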
@@ -16,7 +16,7 @@ def available_providers() -> List[ProviderSpec]:
             provider_type="inline::torchtune",
             pip_packages=["torch", "torchtune", "torchao", "numpy"],
             module="llama_stack.providers.inline.post_training.torchtune",
-            config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
@@ -49,9 +49,7 @@ providers:
   post_training:
   - provider_id: meta-reference-post-training
     provider_type: inline::torchtune
-    config:
-      model: ${env.POST_TRAINING_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+    config: {}

 metadata_store:
   namespace: null