diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 758817ac0..3c6918786 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,12 +6,12 @@
 from datetime import datetime
 from enum import Enum
-
 from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
@@ -79,6 +79,11 @@ class QATFinetuningConfig(BaseModel):
     group_size: int
 
 
+AlgorithmConfig = Annotated[
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
+]
+
+
 @json_schema_type
 class PostTrainingJobLogStream(BaseModel):
     """Stream of logs from a finetuning job."""
@@ -173,9 +178,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[
-            Union[LoraFinetuningConfig, QATFinetuningConfig]
-        ] = None,
+        algorithm_config: Optional[AlgorithmConfig] = None,
    ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize")
diff --git a/llama_stack/providers/inline/post_training/torchtune/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py
index 247ae22b2..7ef8eee01 100644
--- a/llama_stack/providers/inline/post_training/torchtune/__init__.py
+++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py
@@ -22,5 +22,6 @@ async def get_provider_impl(
     impl = TorchtunePostTrainingImpl(
         config,
         deps[Api.datasetio],
+        deps[Api.datasets],
     )
     return impl
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index f33ca059a..1987086e1 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -15,10 +15,14 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
 
 class TorchtunePostTrainingImpl:
     def __init__(
-        self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
+        self,
+        config: TorchtunePostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets
 
     async def supervised_fine_tune(
         self,
@@ -40,6 +44,7 @@ class TorchtunePostTrainingImpl:
             checkpoint_dir,
             algorithm_config,
             self.datasetio_api,
+            self.datasets_api,
         )
         await recipe.setup()
         await recipe.train()
@@ -58,7 +63,7 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    # TODO @markchen1015 impelment below APIs
+    # TODO @SLR722 implement below APIs
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
     # sends SSE stream of logs
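# A minimal, self-contained sketch of how the new AlgorithmConfig discriminated union
# is expected to behave. ExampleLoraConfig/ExampleQATConfig and their fields are
# hypothetical stand-ins for the real LoraFinetuningConfig/QATFinetuningConfig; the
# sketch only assumes each union member carries a literal "type" field for pydantic
# to discriminate on.
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class ExampleLoraConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int = 8


class ExampleQATConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    group_size: int = 32


ExampleAlgorithmConfig = Annotated[
    Union[ExampleLoraConfig, ExampleQATConfig], Field(discriminator="type")
]


class ExampleFineTuneRequest(BaseModel):
    algorithm_config: ExampleAlgorithmConfig


# The "type" value alone routes the payload to the matching config class, which is
# what lets supervised_fine_tune accept a single Optional[AlgorithmConfig] parameter.
request = ExampleFineTuneRequest(algorithm_config={"type": "QAT", "group_size": 64})
assert isinstance(request.algorithm_config, ExampleQATConfig)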
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 17d3cbc2c..7873c7c6f 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -69,6 +69,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
+        datasets_api: Datasets,
     ) -> None:
         self.training_config = training_config
         self.algorithm_config = algorithm_config
@@ -98,7 +99,7 @@
             model = resolve_model(self.model_id)
             self.checkpoint_dir = model_checkpoint_dir(model)
 
-        # TODO @markchen1015 make it work with get_training_job_artifacts
+        # TODO @SLR722 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"
 
         self.seed = training.set_seed(seed=config.torch_seed)
@@ -126,6 +127,7 @@
         )
 
         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
 
     async def load_checkpoint(self):
         def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
@@ -142,7 +144,7 @@
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
-            model_type=utils.get_checkpointer_model_type(self.model_id),
+            model_type=await utils.get_checkpointer_model_type(self.model_id),
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict
@@ -222,7 +224,7 @@
            self._use_dora = self.algorithm_config.use_dora or False
 
        with training.set_default_dtype(self._dtype), self._device:
-            model_type = utils.get_model_type(self.model_id)
+            model_type = await utils.get_model_definition(self.model_id)
            model = model_type(
                lora_attn_modules=self._lora_attn_modules,
                apply_lora_to_mlp=self._apply_lora_to_mlp,
@@ -289,7 +291,7 @@
         self,
     ) -> Llama3Tokenizer:
         tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
-        tokenizer_type = utils.get_tokenizer_type(self.model_id)
+        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
         return tokenizer_type(path=tokenizer_path)
 
     async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
@@ -305,9 +307,11 @@
     async def _setup_data(
         self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
     ) -> Tuple[DistributedSampler, DataLoader]:
+        dataset_id = self.training_config.data_config.dataset_id
+
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.training_config.data_config.dataset_id,
+                dataset_id=dataset_id,
                 rows_in_page=-1,
             )
 
@@ -315,7 +319,13 @@
         rows = all_rows.rows
 
         # Curretly only support alpaca instruct dataset
-        # TODO @markchen1015 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
+        await utils.validate_input_dataset_schema(
+            datasets_api=self.datasets_api,
+            dataset_id=dataset_id,
+            dataset_type="alpaca",
+        )
         ds = SFTDataset(
             rows,
             message_transform=AlpacaToMessages(train_on_input=False),
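# A hypothetical example of the "alpaca"-style row the recipe now validates before
# building SFTDataset. The field values are illustrative only; real rows come back
# from datasetio_api.get_rows_paginated(dataset_id=..., rows_in_page=-1) and must
# match one of the accepted column layouts (instruction/input/output, with "text"
# optional) before AlpacaToMessages is applied.
example_row = {
    "instruction": "Rewrite the sentence in the passive voice.",
    "input": "The model generated the summary.",
    "output": "The summary was generated by the model.",
}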
diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py
index 93c7ef189..6cbee8766 100644
--- a/llama_stack/providers/inline/post_training/torchtune/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/utils.py
@@ -10,49 +10,97 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Callable, Dict
+from enum import Enum
+from typing import Any, Callable, Dict, List
 
 import torch
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.common.type_system import *  # noqa
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.type_system import ParamType
 
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 
-LORA_MODEL_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": lora_llama3_2_3b,
-    "Llama-3-8B-Instruct": lora_llama3_8b,
+
+class ColumnName(Enum):
+    instruction = "instruction"
+    input = "input"
+    output = "output"
+    text = "text"
+
+
+MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
+    "Llama3.2-3B-Instruct": {
+        "model_definition": lora_llama3_2_3b,
+        "tokenizer_type": llama3_tokenizer,
+        "checkpoint_type": "LLAMA3_2",
+    },
+    "Llama-3-8B-Instruct": {
+        "model_definition": lora_llama3_8b,
+        "tokenizer_type": llama3_tokenizer,
+        "checkpoint_type": "LLAMA3",
+    },
 }
 
-TOKENIZER_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": llama3_tokenizer,
-    "Llama-3-8B-Instruct": llama3_tokenizer,
-}
-
-CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
-    "Llama3.2-3B-Instruct": "LLAMA3_2",
+EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
+    "alpaca": [
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+            ColumnName.text.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+    ]
 }
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
 
 
-def get_model_type(
+async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
     model = resolve_model(model_id)
-    return LORA_MODEL_TYPES[model.core_model_id.value]
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_id} is not supported.")
+    return MODEL_CONFIGS[model.core_model_id.value]["model_definition"]
 
 
-def get_tokenizer_type(
+async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
     model = resolve_model(model_id)
-    return TOKENIZER_TYPES[model.core_model_id.value]
+    return MODEL_CONFIGS[model.core_model_id.value]["tokenizer_type"]
 
 
-def get_checkpointer_model_type(
+async def get_checkpointer_model_type(
     model_id: str,
 ) -> str:
     model = resolve_model(model_id)
-    return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
+    return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]:
+        raise ValueError(
+            f"Dataset {dataset_id} does not have a correct input schema; expected one of {EXPECTED_DATASET_SCHEMA[dataset_type]}"
+        )
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index 2c9fdd43d..af8b660fa 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
+                Api.datasets,
             ],
         ),
     ]
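# A rough sketch of the wiring the added Api.datasets dependency implies, assuming the
# stack resolver hands get_provider_impl a deps mapping keyed by Api (the import path
# for Api follows the registry's usage and is assumed here). build_post_training_impl
# and the *_impl arguments are hypothetical placeholders, not real resolver output.
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.post_training.torchtune import get_provider_impl


async def build_post_training_impl(config, datasetio_impl, datasets_impl):
    # The torchtune provider now needs both the DatasetIO and Datasets implementations
    # in its deps mapping, matching the registry declaration above.
    deps = {
        Api.datasetio: datasetio_impl,
        Api.datasets: datasets_impl,
    }
    return await get_provider_impl(config, deps)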