address comments

Botao Chen 2024-12-10 20:50:17 -08:00
parent 9c1ae088f9
commit 68ebf8a8da
6 changed files with 97 additions and 29 deletions

View file

@@ -6,12 +6,12 @@
 from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
@@ -79,6 +79,11 @@ class QATFinetuningConfig(BaseModel):
     group_size: int


+AlgorithmConfig = Annotated[
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
+]
+
+
 @json_schema_type
 class PostTrainingJobLogStream(BaseModel):
     """Stream of logs from a finetuning job."""
@ -173,9 +178,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[ algorithm_config: Optional[AlgorithmConfig] = None,
Union[LoraFinetuningConfig, QATFinetuningConfig]
] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize") @webmethod(route="/post-training/preference-optimize")
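
The AlgorithmConfig alias added above uses Pydantic's discriminated-union pattern: the `type` field on each config variant tells the validator which model to construct. A minimal, self-contained sketch of that pattern (the toy classes, field names, and values below are illustrative, not taken from this repo):

    from typing import Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter
    from typing_extensions import Annotated


    class LoraConfig(BaseModel):
        type: Literal["LoRA"] = "LoRA"
        rank: int = 8


    class QATConfig(BaseModel):
        type: Literal["QAT"] = "QAT"
        group_size: int = 32


    # The discriminator tells Pydantic to dispatch on the `type` field.
    ToyAlgorithmConfig = Annotated[
        Union[LoraConfig, QATConfig], Field(discriminator="type")
    ]

    # Validation picks the matching variant from the `type` value.
    cfg = TypeAdapter(ToyAlgorithmConfig).validate_python(
        {"type": "QAT", "group_size": 64}
    )
    assert isinstance(cfg, QATConfig)

Note that TypeAdapter is a Pydantic v2 API; on v1 the same union would typically be validated through a field of a parent BaseModel instead.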

View file

@@ -22,5 +22,6 @@ async def get_provider_impl(
     impl = TorchtunePostTrainingImpl(
         config,
         deps[Api.datasetio],
+        deps[Api.datasets],
     )
     return impl

View file

@@ -15,10 +15,14 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
 class TorchtunePostTrainingImpl:
     def __init__(
-        self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
+        self,
+        config: TorchtunePostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets

     async def supervised_fine_tune(
         self,
@@ -40,6 +44,7 @@ class TorchtunePostTrainingImpl:
             checkpoint_dir,
             algorithm_config,
             self.datasetio_api,
+            self.datasets_api,
         )
         await recipe.setup()
         await recipe.train()
@@ -58,7 +63,7 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    # TODO @markchen1015 impelment below APIs
+    # TODO @SLR722 impelment below APIs
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs

View file

@@ -69,6 +69,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
+        datasets_api: Datasets,
     ) -> None:
         self.training_config = training_config
         self.algorithm_config = algorithm_config
@@ -98,7 +99,7 @@ class LoraFinetuningSingleDevice:
         model = resolve_model(self.model_id)
         self.checkpoint_dir = model_checkpoint_dir(model)

-        # TODO @markchen1015 make it work with get_training_job_artifacts
+        # TODO @SLR722 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"

         self.seed = training.set_seed(seed=config.torch_seed)
@@ -126,6 +127,7 @@ class LoraFinetuningSingleDevice:
         )

         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api

     async def load_checkpoint(self):
         def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
@@ -142,7 +144,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
-            model_type=utils.get_checkpointer_model_type(self.model_id),
+            model_type=await utils.get_checkpointer_model_type(self.model_id),
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict
@@ -222,7 +224,7 @@ class LoraFinetuningSingleDevice:
         self._use_dora = self.algorithm_config.use_dora or False

         with training.set_default_dtype(self._dtype), self._device:
-            model_type = utils.get_model_type(self.model_id)
+            model_type = await utils.get_model_definition(self.model_id)
             model = model_type(
                 lora_attn_modules=self._lora_attn_modules,
                 apply_lora_to_mlp=self._apply_lora_to_mlp,
@@ -289,7 +291,7 @@ class LoraFinetuningSingleDevice:
         self,
     ) -> Llama3Tokenizer:
         tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
-        tokenizer_type = utils.get_tokenizer_type(self.model_id)
+        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
         return tokenizer_type(path=tokenizer_path)

     async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
@@ -305,9 +307,11 @@ class LoraFinetuningSingleDevice:
     async def _setup_data(
         self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
     ) -> Tuple[DistributedSampler, DataLoader]:
+        dataset_id = self.training_config.data_config.dataset_id
+
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.training_config.data_config.dataset_id,
+                dataset_id=dataset_id,
                 rows_in_page=-1,
             )
@@ -315,7 +319,13 @@ class LoraFinetuningSingleDevice:
         rows = all_rows.rows

         # Curretly only support alpaca instruct dataset
-        # TODO @markchen1015 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
+        await utils.validate_input_dataset_schema(
+            datasets_api=self.datasets_api,
+            dataset_id=dataset_id,
+            dataset_type="alpaca",
+        )
         ds = SFTDataset(
             rows,
             message_transform=AlpacaToMessages(train_on_input=False),
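
For reference, the "alpaca" layouts accepted by the schema check above correspond to rows shaped like the following (sample values are invented for illustration; the accepted column combinations come from EXPECTED_DATASET_SCHEMA in the utils change below):

    # Full layout: instruction / input / output / text
    row_full = {
        "instruction": "Summarize the passage.",
        "input": "The quick brown fox jumps over the lazy dog.",
        "output": "A fox jumps over a dog.",
        "text": "Summarize the passage.\nThe quick brown fox jumps over the lazy dog.",
    }

    # Also accepted: instruction / input / output, or just instruction / output
    row_no_text = {"instruction": "...", "input": "...", "output": "..."}
    row_minimal = {"instruction": "...", "output": "..."}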

View file

@@ -10,49 +10,97 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Callable, Dict
+from enum import Enum
+from typing import Any, Callable, Dict, List

 import torch
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.common.type_system import *  # noqa
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.type_system import ParamType

 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b

-LORA_MODEL_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": lora_llama3_2_3b,
-    "Llama-3-8B-Instruct": lora_llama3_8b,
+
+class ColumnName(Enum):
+    instruction = "instruction"
+    input = "input"
+    output = "output"
+    text = "text"
+
+
+MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
+    "Llama3.2-3B-Instruct": {
+        "model_definition": lora_llama3_2_3b,
+        "tokenizer_type": llama3_tokenizer,
+        "checkpoint_type": "LLAMA3_2",
+    },
+    "Llama-3-8B-Instruct": {
+        "model_definition": lora_llama3_8b,
+        "tokenizer_type": llama3_tokenizer,
+        "checkpoint_type": "LLAMA3",
+    },
 }

-TOKENIZER_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": llama3_tokenizer,
-    "Llama-3-8B-Instruct": llama3_tokenizer,
-}
-
-CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
-    "Llama3.2-3B-Instruct": "LLAMA3_2",
+EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
+    "alpaca": [
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+            ColumnName.text.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+    ]
 }

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]

-def get_model_type(
+async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
     model = resolve_model(model_id)
-    return LORA_MODEL_TYPES[model.core_model_id.value]
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_id} is not supported.")
+    return MODEL_CONFIGS[model.core_model_id.value]["model_definition"]

-def get_tokenizer_type(
+async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
     model = resolve_model(model_id)
-    return TOKENIZER_TYPES[model.core_model_id.value]
+    return MODEL_CONFIGS[model.core_model_id.value]["tokenizer_type"]

-def get_checkpointer_model_type(
+async def get_checkpointer_model_type(
     model_id: str,
 ) -> str:
     model = resolve_model(model_id)
-    return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
+    return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]:
+        raise ValueError(
+            f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA[dataset_type]}"
+        )
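
The `not in` membership test above works because each schema is a plain dict of column name to type instance, and Pydantic models with identical field values compare equal, so two StringType() instances match. A minimal sketch of that comparison logic, using a stand-in class rather than the real llama_stack StringType:

    from typing import Dict, List

    from pydantic import BaseModel


    class StringType(BaseModel):
        # Stand-in for llama_stack.apis.common.type_system.StringType.
        type: str = "string"


    ACCEPTED: List[Dict[str, BaseModel]] = [
        {"instruction": StringType(), "input": StringType(), "output": StringType()},
        {"instruction": StringType(), "output": StringType()},
    ]

    # Equal-valued models compare equal, so plain dict membership is enough
    # to accept or reject a candidate dataset schema.
    assert {"instruction": StringType(), "output": StringType()} in ACCEPTED
    assert {"prompt": StringType(), "completion": StringType()} not in ACCEPTED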

View file

@@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
+                Api.datasets,
             ],
         ),
     ]