forked from phoenix-oss/llama-stack-mirror
feat: make training config fields optional (#1861)
# What does this PR do? Today, supervised_fine_tune itself and the `TrainingConfig` class have a bunch of required fields that a provider implementation might not need. for example, if a provider wants to handle hyperparameters in its configuration as well as any type of dataset retrieval, optimizer or LoRA config, a user will still need to pass in a virtually empty `DataConfig`, `OptimizerConfig` and `AlgorithmConfig` in some cases. Many of these fields are intended to work specifically with llama models and knobs intended for customizing inline. Adding remote post_training providers will require loosening these arguments, or forcing users to pass in empty objects to satisfy the pydantic models. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
70a7e4d51e
commit
0751a960a5
4 changed files with 29 additions and 21 deletions
17
docs/_static/llama-stack-spec.html
vendored
17
docs/_static/llama-stack-spec.html
vendored
|
@ -9778,13 +9778,16 @@
|
|||
"type": "integer"
|
||||
},
|
||||
"max_steps_per_epoch": {
|
||||
"type": "integer"
|
||||
"type": "integer",
|
||||
"default": 1
|
||||
},
|
||||
"gradient_accumulation_steps": {
|
||||
"type": "integer"
|
||||
"type": "integer",
|
||||
"default": 1
|
||||
},
|
||||
"max_validation_steps": {
|
||||
"type": "integer"
|
||||
"type": "integer",
|
||||
"default": 1
|
||||
},
|
||||
"data_config": {
|
||||
"$ref": "#/components/schemas/DataConfig"
|
||||
|
@ -9804,10 +9807,7 @@
|
|||
"required": [
|
||||
"n_epochs",
|
||||
"max_steps_per_epoch",
|
||||
"gradient_accumulation_steps",
|
||||
"max_validation_steps",
|
||||
"data_config",
|
||||
"optimizer_config"
|
||||
"gradient_accumulation_steps"
|
||||
],
|
||||
"title": "TrainingConfig"
|
||||
},
|
||||
|
@ -10983,8 +10983,7 @@
|
|||
"job_uuid",
|
||||
"training_config",
|
||||
"hyperparam_search_config",
|
||||
"logger_config",
|
||||
"model"
|
||||
"logger_config"
|
||||
],
|
||||
"title": "SupervisedFineTuneRequest"
|
||||
},
|
||||
|
|
7
docs/_static/llama-stack-spec.yaml
vendored
7
docs/_static/llama-stack-spec.yaml
vendored
|
@ -6744,10 +6744,13 @@ components:
|
|||
type: integer
|
||||
max_steps_per_epoch:
|
||||
type: integer
|
||||
default: 1
|
||||
gradient_accumulation_steps:
|
||||
type: integer
|
||||
default: 1
|
||||
max_validation_steps:
|
||||
type: integer
|
||||
default: 1
|
||||
data_config:
|
||||
$ref: '#/components/schemas/DataConfig'
|
||||
optimizer_config:
|
||||
|
@ -6762,9 +6765,6 @@ components:
|
|||
- n_epochs
|
||||
- max_steps_per_epoch
|
||||
- gradient_accumulation_steps
|
||||
- max_validation_steps
|
||||
- data_config
|
||||
- optimizer_config
|
||||
title: TrainingConfig
|
||||
PreferenceOptimizeRequest:
|
||||
type: object
|
||||
|
@ -7498,7 +7498,6 @@ components:
|
|||
- training_config
|
||||
- hyperparam_search_config
|
||||
- logger_config
|
||||
- model
|
||||
title: SupervisedFineTuneRequest
|
||||
SyntheticDataGenerateRequest:
|
||||
type: object
|
||||
|
|
|
@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
|
|||
@json_schema_type
|
||||
class TrainingConfig(BaseModel):
|
||||
n_epochs: int
|
||||
max_steps_per_epoch: int
|
||||
gradient_accumulation_steps: int
|
||||
max_validation_steps: int
|
||||
data_config: DataConfig
|
||||
optimizer_config: OptimizerConfig
|
||||
max_steps_per_epoch: int = 1
|
||||
gradient_accumulation_steps: int = 1
|
||||
max_validation_steps: Optional[int] = 1
|
||||
data_config: Optional[DataConfig] = None
|
||||
optimizer_config: Optional[OptimizerConfig] = None
|
||||
efficiency_config: Optional[EfficiencyConfig] = None
|
||||
dtype: Optional[str] = "bf16"
|
||||
|
||||
|
@ -177,9 +177,9 @@ class PostTraining(Protocol):
|
|||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
model: str = Field(
|
||||
default="Llama3.2-3B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
model: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Model descriptor for training if not in provider config`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
|
|
|
@ -38,6 +38,8 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
EfficiencyConfig,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
|
@ -89,6 +91,10 @@ class LoraFinetuningSingleDevice:
|
|||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
|
||||
assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
|
||||
|
||||
self.job_uuid = job_uuid
|
||||
self.training_config = training_config
|
||||
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
|
@ -188,6 +194,7 @@ class LoraFinetuningSingleDevice:
|
|||
self._tokenizer = await self._setup_tokenizer()
|
||||
log.info("Tokenizer is initialized.")
|
||||
|
||||
assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
|
||||
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
|
||||
log.info("Optimizer is initialized.")
|
||||
|
||||
|
@ -195,6 +202,8 @@ class LoraFinetuningSingleDevice:
|
|||
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
|
||||
log.info("Loss is initialized.")
|
||||
|
||||
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
|
||||
self._training_sampler, self._training_dataloader = await self._setup_data(
|
||||
dataset_id=self.training_config.data_config.dataset_id,
|
||||
tokenizer=self._tokenizer,
|
||||
|
@ -452,6 +461,7 @@ class LoraFinetuningSingleDevice:
|
|||
"""
|
||||
The core training loop.
|
||||
"""
|
||||
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
t0 = time.perf_counter()
|
||||
running_loss: float = 0.0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue