Botao Chen 2024-12-09 13:35:44 -08:00
parent 5a628d32c8
commit 9c1ae088f9
4 changed files with 7 additions and 16 deletions

View file

@@ -182,12 +182,8 @@ class PostTraining(Protocol):
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],

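Assembled from the hunk above, the updated protocol method reads roughly as follows (a reconstruction, not verbatim code; the sibling config classes are defined in the same API module and their definitions are elided here):

```python
from typing import Any, Dict, Protocol

class PostTraining(Protocol):
    async def preference_optimize(
        self,
        job_uuid: str,
        # now a plain model identifier string instead of a URL
        finetuned_model: str,
        # dataset ids and the RLHFAlgorithm enum are dropped; the DPO choice
        # is carried by the algorithm_config type itself
        algorithm_config: DPOAlignmentConfig,
        optimizer_config: OptimizerConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ): ...
```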
View file

@@ -10,6 +10,8 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 from .config import TorchtunePostTrainingConfig
+# The post_training API and the torchtune provider are still experimental and under heavy development
 async def get_provider_impl(
     config: TorchtunePostTrainingConfig,
@@ -21,5 +23,4 @@ async def get_provider_impl(
         config,
         deps[Api.datasetio],
     )
-    # await impl.initialize()
     return impl

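Putting the two hunks together, the provider factory ends up roughly like this (a sketch; the `deps` annotation and the `TorchtunePostTrainingImpl` import path are assumptions inferred from the surrounding lines, not part of this commit):

```python
from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import TorchtunePostTrainingConfig

# The post_training API and the torchtune provider are still experimental and under heavy development
async def get_provider_impl(
    config: TorchtunePostTrainingConfig,
    deps: Dict[Api, ProviderSpec],  # assumed annotation, based on the imports above
):
    from .post_training import TorchtunePostTrainingImpl  # assumed import path

    impl = TorchtunePostTrainingImpl(
        config,
        deps[Api.datasetio],
    )
    # the commented-out `await impl.initialize()` line was removed in this commit
    return impl
```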
View file

@@ -51,12 +51,8 @@ class TorchtunePostTrainingImpl:
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],

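On the implementation side the signature now mirrors the protocol. A hypothetical caller-side sketch of the new shape (none of these argument values come from this commit):

```python
# Hypothetical invocation; `impl` is a TorchtunePostTrainingImpl and the
# config objects are assumed to be built elsewhere with valid fields.
job = await impl.preference_optimize(
    job_uuid="dpo-job-001",                   # made-up identifier
    finetuned_model="Llama3.2-3B-Instruct",   # plain string, no longer a URL
    algorithm_config=dpo_config,              # a DPOAlignmentConfig
    optimizer_config=optimizer_config,        # an OptimizerConfig
    training_config=training_config,          # a TrainingConfig
    hyperparam_search_config={},
    logger_config={},
)
```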
View file

@@ -57,8 +57,8 @@ class LoraFinetuningSingleDevice:
     # Resume from checkpoint isn't supported yet
     # Validation isn't supported yet
-    # TODO @markchen1015 figure out the logging for this training recipe
-    # and make it work with telemetry
+    # Currently logging only writes limited training metrics to local disk;
+    # will figure out more logging and how it works with telemetry in future PRs
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
@@ -70,7 +70,6 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
-        # Assume the training only happens on GPU
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
@@ -149,7 +148,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict

     async def setup(self) -> None:
-        # temporarily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)
         checkpoint_dict = await self.load_checkpoint()
@@ -316,7 +314,7 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows()
         rows = all_rows.rows
-        # Currently only supports instruct datasets
+        # Currently only supports the alpaca instruct dataset
         # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
             rows,
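The `SFTDataset` call is truncated by the hunk boundary. Given the "alpaca instruct dataset" comment, it plausibly continues along these lines (a sketch; `AlpacaToMessages` from torchtune and the tokenizer attribute are assumptions, not verbatim from this commit):

```python
from torchtune.data import AlpacaToMessages  # assumed transform

ds = SFTDataset(
    rows,
    # maps alpaca-style rows into torchtune's message format (assumed)
    message_transform=AlpacaToMessages(train_on_input=False),
    model_transform=self._tokenizer,  # assumed tokenizer attribute
)
```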