mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-20 16:58:41 +00:00
Merge branch 'main' into post_training_v4
This commit is contained in:
commit
018dce89ca
287 changed files with 13743 additions and 4540 deletions
|
|
@ -22,5 +22,6 @@ async def get_provider_impl(
|
|||
impl = TorchtunePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -10,49 +10,130 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.common.type_system import * # noqa
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
||||
LORA_MODEL_TYPES: Dict[str, Any] = {
|
||||
"Llama3.2-3B-Instruct": lora_llama3_2_3b,
|
||||
"Llama-3-8B-Instruct": lora_llama3_8b,
|
||||
|
||||
class ColumnName(Enum):
|
||||
instruction = "instruction"
|
||||
input = "input"
|
||||
output = "output"
|
||||
text = "text"
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
class DatasetSchema(BaseModel):
|
||||
alpaca: List[Dict[str, ParamType]]
|
||||
|
||||
|
||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama-3-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
}
|
||||
|
||||
TOKENIZER_TYPES: Dict[str, Any] = {
|
||||
"Llama3.2-3B-Instruct": llama3_tokenizer,
|
||||
"Llama-3-8B-Instruct": llama3_tokenizer,
|
||||
}
|
||||
|
||||
CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
|
||||
"Llama3.2-3B-Instruct": "LLAMA3_2",
|
||||
}
|
||||
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
||||
alpaca=[
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
ColumnName.text.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def get_model_type(
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
raise ValueError(f"Model {model_id} is not supported.")
|
||||
return model
|
||||
|
||||
|
||||
async def get_model_definition(
|
||||
model_id: str,
|
||||
) -> BuildLoraModelCallable:
|
||||
model = resolve_model(model_id)
|
||||
return LORA_MODEL_TYPES[model.core_model_id.value]
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "model_definition"):
|
||||
raise ValueError(f"Model {model_id} does not have model definition.")
|
||||
return model_config.model_definition
|
||||
|
||||
|
||||
def get_tokenizer_type(
|
||||
async def get_tokenizer_type(
|
||||
model_id: str,
|
||||
) -> BuildTokenizerCallable:
|
||||
model = resolve_model(model_id)
|
||||
return TOKENIZER_TYPES[model.core_model_id.value]
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "tokenizer_type"):
|
||||
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
||||
return model_config.tokenizer_type
|
||||
|
||||
|
||||
def get_checkpointer_model_type(
|
||||
async def get_checkpointer_model_type(
|
||||
model_id: str,
|
||||
) -> str:
|
||||
model = resolve_model(model_id)
|
||||
return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
|
||||
"""
|
||||
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
||||
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
||||
"""
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "checkpoint_type"):
|
||||
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
||||
return model_config.checkpoint_type
|
||||
|
||||
|
||||
async def validate_input_dataset_schema(
|
||||
datasets_api: Datasets,
|
||||
dataset_id: str,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||
|
||||
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
|
|||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs_status = {}
|
||||
|
|
@ -33,10 +37,11 @@ class TorchtunePostTrainingImpl:
|
|||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
) -> PostTrainingJob:
|
||||
if job_uuid in self.jobs_list:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
for job in self.jobs_list:
|
||||
if job_uuid == job.job_uuid:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
|
||||
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
|
|
@ -59,6 +64,7 @@ class TorchtunePostTrainingImpl:
|
|||
checkpoint_dir,
|
||||
algorithm_config,
|
||||
self.datasetio_api,
|
||||
self.datasets_api,
|
||||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
|
|
|
|||
|
|
@ -75,11 +75,16 @@ class LoraFinetuningSingleDevice:
|
|||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
self.job_uuid = job_uuid
|
||||
self.training_config = training_config
|
||||
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
raise ValueError(
|
||||
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
|
||||
)
|
||||
self.algorithm_config = algorithm_config
|
||||
self._device = torchtune_utils.get_device(device="cuda")
|
||||
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
|
||||
|
|
@ -107,7 +112,6 @@ class LoraFinetuningSingleDevice:
|
|||
model = resolve_model(self.model_id)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
|
||||
# TODO @markchen1015 make it work with get_training_job_artifacts
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
|
||||
self.seed = training.set_seed(seed=config.torch_seed)
|
||||
|
|
@ -135,6 +139,7 @@ class LoraFinetuningSingleDevice:
|
|||
)
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
||||
async def load_checkpoint(self):
|
||||
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
|
||||
|
|
@ -153,7 +158,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_dir=self.checkpoint_dir,
|
||||
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
|
||||
output_dir=self._output_dir,
|
||||
model_type=utils.get_checkpointer_model_type(self.model_id),
|
||||
model_type=await utils.get_checkpointer_model_type(self.model_id),
|
||||
)
|
||||
checkpoint_dict = self._checkpointer.load_checkpoint()
|
||||
return checkpoint_dict
|
||||
|
|
@ -241,7 +246,7 @@ class LoraFinetuningSingleDevice:
|
|||
self._use_dora = self.algorithm_config.use_dora or False
|
||||
|
||||
with training.set_default_dtype(self._dtype), self._device:
|
||||
model_type = utils.get_model_type(self.model_id)
|
||||
model_type = await utils.get_model_definition(self.model_id)
|
||||
model = model_type(
|
||||
lora_attn_modules=self._lora_attn_modules,
|
||||
apply_lora_to_mlp=self._apply_lora_to_mlp,
|
||||
|
|
@ -308,7 +313,7 @@ class LoraFinetuningSingleDevice:
|
|||
self,
|
||||
) -> Llama3Tokenizer:
|
||||
tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
|
||||
tokenizer_type = utils.get_tokenizer_type(self.model_id)
|
||||
tokenizer_type = await utils.get_tokenizer_type(self.model_id)
|
||||
return tokenizer_type(path=tokenizer_path)
|
||||
|
||||
async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
|
||||
|
|
@ -338,7 +343,13 @@ class LoraFinetuningSingleDevice:
|
|||
rows = all_rows.rows
|
||||
|
||||
# Curretly only support alpaca instruct dataset
|
||||
# TODO @markchen1015 make the message_transform swappable and support more dataset types
|
||||
# TODO @SLR722 make the message_transform swappable and support more dataset types
|
||||
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
|
||||
await utils.validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
dataset_id=dataset_id,
|
||||
dataset_type="alpaca",
|
||||
)
|
||||
ds = SFTDataset(
|
||||
rows,
|
||||
message_transform=AlpacaToMessages(train_on_input=False),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue