diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 36bfad49e..cdd6b3b53 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9778,13 +9778,16 @@
"type": "integer"
},
"max_steps_per_epoch": {
- "type": "integer"
+ "type": "integer",
+ "default": 1
},
"gradient_accumulation_steps": {
- "type": "integer"
+ "type": "integer",
+ "default": 1
},
"max_validation_steps": {
- "type": "integer"
+ "type": "integer",
+ "default": 1
},
"data_config": {
"$ref": "#/components/schemas/DataConfig"
@@ -9804,10 +9807,7 @@
"required": [
"n_epochs",
"max_steps_per_epoch",
- "gradient_accumulation_steps",
- "max_validation_steps",
- "data_config",
- "optimizer_config"
+ "gradient_accumulation_steps"
],
"title": "TrainingConfig"
},
@@ -10983,8 +10983,7 @@
"job_uuid",
"training_config",
"hyperparam_search_config",
- "logger_config",
- "model"
+ "logger_config"
],
"title": "SupervisedFineTuneRequest"
},
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 82faf450a..aa8d9456e 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6744,10 +6744,13 @@ components:
type: integer
max_steps_per_epoch:
type: integer
+ default: 1
gradient_accumulation_steps:
type: integer
+ default: 1
max_validation_steps:
type: integer
+ default: 1
data_config:
$ref: '#/components/schemas/DataConfig'
optimizer_config:
@@ -6762,9 +6765,6 @@ components:
- n_epochs
- max_steps_per_epoch
- gradient_accumulation_steps
- - max_validation_steps
- - data_config
- - optimizer_config
title: TrainingConfig
PreferenceOptimizeRequest:
type: object
@@ -7498,7 +7498,6 @@ components:
- training_config
- hyperparam_search_config
- logger_config
- - model
title: SupervisedFineTuneRequest
SyntheticDataGenerateRequest:
type: object
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index d49668e23..e5f1bcb65 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
- max_steps_per_epoch: int
- gradient_accumulation_steps: int
- max_validation_steps: int
- data_config: DataConfig
- optimizer_config: OptimizerConfig
+ max_steps_per_epoch: int = 1
+ gradient_accumulation_steps: int = 1
+ max_validation_steps: Optional[int] = 1
+ data_config: Optional[DataConfig] = None
+ optimizer_config: Optional[OptimizerConfig] = None
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
@@ -177,9 +177,9 @@ class PostTraining(Protocol):
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
- model: str = Field(
- default="Llama3.2-3B-Instruct",
- description="Model descriptor from `llama model list`",
+ model: Optional[str] = Field(
+ default=None,
+        description="Model descriptor for training if not in provider config",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index edc1ceb90..04bf86b97 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -38,6 +38,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
+ DataConfig,
+ EfficiencyConfig,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
@@ -89,6 +91,10 @@ class LoraFinetuningSingleDevice:
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
+ assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+
+ assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
+
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
@@ -188,6 +194,7 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
+ assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
log.info("Optimizer is initialized.")
@@ -195,6 +202,8 @@ class LoraFinetuningSingleDevice:
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.")
+ assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
+
self._training_sampler, self._training_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.dataset_id,
tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ class LoraFinetuningSingleDevice:
"""
The core training loop.
"""
+ assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss: float = 0.0