feat: make training config fields optional

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Author: Charlie Doern
Date:   2025-04-02 11:35:23 -04:00
Parent: 66d6c2580e
Commit: 9f5543a643
4 changed files with 29 additions and 21 deletions


@@ -38,6 +38,8 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
+    DataConfig,
+    EfficiencyConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     QATFinetuningConfig,
@@ -89,6 +91,10 @@ class LoraFinetuningSingleDevice:
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+        assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
@@ -188,6 +194,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")
+        assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
         self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")
@@ -195,6 +202,8 @@ class LoraFinetuningSingleDevice:
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         self._training_sampler, self._training_dataloader = await self._setup_data(
             dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ class LoraFinetuningSingleDevice:
         """
         The core training loop.
         """
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
         running_loss: float = 0.0
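
The pattern behind these changes: the training config sub-fields become optional on the API model, and the inline recipe narrows them back to concrete types with isinstance asserts at the point of use. Below is a minimal, self-contained sketch of that pattern; the simplified model definitions, field defaults, and the run_inline_recipe helper are illustrative assumptions, not the actual llama_stack classes.

from typing import Optional

from pydantic import BaseModel


# Simplified stand-ins for the llama_stack post_training models (illustrative only).
class DataConfig(BaseModel):
    dataset_id: str
    batch_size: int = 1
    shuffle: bool = True


class OptimizerConfig(BaseModel):
    lr: float = 1e-4


class EfficiencyConfig(BaseModel):
    enable_activation_checkpointing: bool = False


class TrainingConfig(BaseModel):
    n_epochs: int
    # Previously required; making them Optional lets callers that don't
    # need them omit them entirely.
    data_config: Optional[DataConfig] = None
    optimizer_config: Optional[OptimizerConfig] = None
    efficiency_config: Optional[EfficiencyConfig] = None


def run_inline_recipe(cfg: TrainingConfig) -> None:
    # The inline recipe still requires these fields, so it narrows the
    # Optional types up front -- the same isinstance asserts added in the
    # diff above. This fails fast and also satisfies the type checker.
    assert isinstance(cfg.data_config, DataConfig), "DataConfig must be initialized"
    assert isinstance(cfg.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
    print(f"training on {cfg.data_config.dataset_id} with lr={cfg.optimizer_config.lr}")


# A caller that does not touch these fields can build TrainingConfig(n_epochs=1)
# with no sub-configs; the inline path supplies them explicitly:
run_inline_recipe(
    TrainingConfig(
        n_epochs=1,
        data_config=DataConfig(dataset_id="my-dataset"),
        optimizer_config=OptimizerConfig(),
    )
)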