commit 9c1ae088f9
parent 5a628d32c8
Author: Botao Chen
Date: 2024-12-09 13:35:44 -08:00

4 changed files with 7 additions and 16 deletions

Changed file 1 of 4 (path not shown):

@@ -182,12 +182,8 @@ class PostTraining(Protocol):
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
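Note: a minimal sketch of a call against the slimmed-down signature. Here post_training, dpo_config, and training_config are placeholders, the model string is illustrative, and the assumption is that the dataset reference now travels inside the training config rather than as the removed standalone dataset_id parameters:

    # Hypothetical call; every value here is illustrative, not from this commit.
    job = await post_training.preference_optimize(
        job_uuid="1234",
        finetuned_model="meta-llama/Llama-3.2-3B-Instruct",  # now a plain str, not a URL
        algorithm_config=dpo_config,        # a DPOAlignmentConfig built elsewhere
        training_config=training_config,    # a TrainingConfig built elsewhere
        hyperparam_search_config={},
        logger_config={},
    )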

Changed file 2 of 4 (path not shown):

@@ -10,6 +10,8 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 from .config import TorchtunePostTrainingConfig
 
+# post_training api and the torchtune provider is still experimental and under heavy development
+
 async def get_provider_impl(
     config: TorchtunePostTrainingConfig,
@@ -21,5 +23,4 @@ async def get_provider_impl(
         config,
         deps[Api.datasetio],
     )
-    # await impl.initialize()
     return impl
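Note: a sketch of how a distribution might obtain the provider after this change; datasetio_impl and cfg are hypothetical names, and only the shape visible in the hunks above is assumed:

    # Illustrative wiring only; datasetio_impl and cfg are placeholders.
    from llama_stack.distribution.datatypes import Api

    deps = {Api.datasetio: datasetio_impl}
    impl = await get_provider_impl(cfg, deps)  # returns the torchtune provider impl
    # With the commented-out initialize() call now removed, any heavy setup
    # has to happen lazily inside the individual training entry points.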

Changed file 3 of 4 (path not shown):

@@ -51,12 +51,8 @@ class TorchtunePostTrainingImpl:
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],

Changed file 4 of 4 (path not shown):

@@ -57,8 +57,8 @@ class LoraFinetuningSingleDevice:
     # Resume from checkpoint hasn't been supported yet
     # Validation hasn't been supported yet
-    # TODO @markchen1015 figure out the logging for this training recipe
-    # and make it work with telemetry
+    # Currently logging only logs limited training metrics to local disk
+    # will figure out more loggings and how it works with telemetry in future PRs
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
@@ -70,7 +70,6 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
-        # Assume the training only happens on GPU
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
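Note: the removed comment restated what the code pins down anyway. A minimal sketch of the device resolution, assuming torchtune's get_device helper behaves as its name suggests:

    # get_device(device="cuda") resolves a torch.device for the current GPU;
    # on a host without CUDA this is expected to fail, so the recipe remains
    # GPU-only even without the deleted comment.
    from torchtune import utils as torchtune_utils

    device = torchtune_utils.get_device(device="cuda")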
@@ -149,7 +148,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self) -> None:
-        # temporily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)
 
         checkpoint_dict = await self.load_checkpoint()
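Note: DiskLogger is what the new comment about limited local-disk logging refers to. A small usage sketch, assuming torchtune's metric-logging interface (the import path varies across torchtune versions):

    # Assumed API: DiskLogger writes metrics to plain text files under log_dir.
    from torchtune.training.metric_logging import DiskLogger

    logger = DiskLogger(log_dir="/tmp/torchtune-output")  # illustrative path
    logger.log("loss", 0.42, step=1)
    logger.log_dict({"loss": 0.40, "lr": 2e-4}, step=2)
    logger.close()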
@@ -316,7 +314,7 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows()
         rows = all_rows.rows
 
-        # Curretly only support instruct dataset
+        # Curretly only support alpaca instruct dataset
         # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
             rows,
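Note: the comment change narrows the supported format to alpaca-style instruct data. A hypothetical sketch of what an alpaca-style message transform does; the row keys and function name are illustrative, not the provider's actual transform:

    # Hypothetical alpaca-style transform: map a raw row with
    # instruction/input/output keys into a user/assistant message pair.
    def alpaca_to_messages(row: dict) -> list[dict]:
        prompt = row["instruction"]
        if row.get("input"):
            prompt += "\n\n" + row["input"]
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": row["output"]},
        ]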