feat: make training config fields optional (#1861)
# What does this PR do?

Today, `supervised_fine_tune` and the `TrainingConfig` class have a number of required fields that a provider implementation might not need. For example, if a provider handles hyperparameters in its own configuration, along with dataset retrieval and the optimizer or LoRA settings, a user still has to pass in a virtually empty `DataConfig`, `OptimizerConfig`, and `AlgorithmConfig` in some cases. Many of these fields exist specifically for Llama models and for knobs used to customize the inline provider. Adding remote post_training providers therefore requires loosening these arguments; otherwise users are forced to pass empty objects just to satisfy the pydantic models.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
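As a concrete illustration of the relaxed schema, the sketch below builds a `TrainingConfig` with only `n_epochs` set, which the current pydantic model rejects but which validates once the other fields gain defaults. This is a minimal example assuming `llama_stack` is installed; the values are arbitrary.

```python
from llama_stack.apis.post_training import TrainingConfig

# Before this change, max_steps_per_epoch, gradient_accumulation_steps,
# max_validation_steps, data_config, and optimizer_config were all required.
# Afterwards, a provider that manages those settings itself only needs:
config = TrainingConfig(n_epochs=1)

print(config.max_steps_per_epoch)          # 1 (new default)
print(config.gradient_accumulation_steps)  # 1 (new default)
print(config.data_config)                  # None; the provider supplies its own
```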
commit 0751a960a5 (parent 70a7e4d51e)
4 changed files with 29 additions and 21 deletions
docs/_static/llama-stack-spec.html (vendored), 17 changes:

```diff
@@ -9778,13 +9778,16 @@
             "type": "integer"
           },
           "max_steps_per_epoch": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "gradient_accumulation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "max_validation_steps": {
-            "type": "integer"
+            "type": "integer",
+            "default": 1
           },
           "data_config": {
             "$ref": "#/components/schemas/DataConfig"
@@ -9804,10 +9807,7 @@
         "required": [
           "n_epochs",
           "max_steps_per_epoch",
-          "gradient_accumulation_steps",
-          "max_validation_steps",
-          "data_config",
-          "optimizer_config"
+          "gradient_accumulation_steps"
         ],
         "title": "TrainingConfig"
       },
@@ -10983,8 +10983,7 @@
           "job_uuid",
           "training_config",
           "hyperparam_search_config",
-          "logger_config",
-          "model"
+          "logger_config"
         ],
         "title": "SupervisedFineTuneRequest"
       },
```
docs/_static/llama-stack-spec.yaml (vendored), 7 changes:

```diff
@@ -6744,10 +6744,13 @@ components:
          type: integer
        max_steps_per_epoch:
          type: integer
+          default: 1
        gradient_accumulation_steps:
          type: integer
+          default: 1
        max_validation_steps:
          type: integer
+          default: 1
        data_config:
          $ref: '#/components/schemas/DataConfig'
        optimizer_config:
@@ -6762,9 +6765,6 @@ components:
        - n_epochs
        - max_steps_per_epoch
        - gradient_accumulation_steps
-        - max_validation_steps
-        - data_config
-        - optimizer_config
        title: TrainingConfig
      PreferenceOptimizeRequest:
        type: object
@@ -7498,7 +7498,6 @@ components:
        - training_config
        - hyperparam_search_config
        - logger_config
-        - model
        title: SupervisedFineTuneRequest
      SyntheticDataGenerateRequest:
        type: object
```
Post-training API (`TrainingConfig` and the `PostTraining` protocol):

```diff
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
 @json_schema_type
 class TrainingConfig(BaseModel):
     n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: Optional[int] = 1
+    data_config: Optional[DataConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
     efficiency_config: Optional[EfficiencyConfig] = None
     dtype: Optional[str] = "bf16"

@@ -177,9 +177,9 @@ class PostTraining(Protocol):
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        model: str = Field(
-            default="Llama3.2-3B-Instruct",
-            description="Model descriptor from `llama model list`",
+        model: Optional[str] = Field(
+            default=None,
+            description="Model descriptor for training if not in provider config`",
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
```
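With the signature above, a caller targeting a provider that resolves the model, dataset, and optimizer from its own configuration can omit `model`, `data_config`, and `optimizer_config`. A hypothetical sketch against an object implementing the `PostTraining` protocol follows; the job identifier is made up, and whether the provider actually supplies those settings is an assumption.

```python
from llama_stack.apis.post_training import PostTraining, TrainingConfig

async def start_job(post_training: PostTraining):
    # `model`, `checkpoint_dir`, and `algorithm_config` are now optional and
    # are omitted here on the assumption that the provider config supplies them.
    return await post_training.supervised_fine_tune(
        job_uuid="job-0001",  # made-up identifier
        training_config=TrainingConfig(n_epochs=1),
        hyperparam_search_config={},
        logger_config={},
    )
```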
Inline torchtune recipe (`LoraFinetuningSingleDevice`):

```diff
@@ -38,6 +38,8 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     Checkpoint,
+    DataConfig,
+    EfficiencyConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     QATFinetuningConfig,
@@ -89,6 +91,10 @@ class LoraFinetuningSingleDevice:
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
+
+        assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
+
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
@@ -188,6 +194,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

+        assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
         self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

@@ -195,6 +202,8 @@ class LoraFinetuningSingleDevice:
         self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
         log.info("Loss is initialized.")

+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
+
         self._training_sampler, self._training_dataloader = await self._setup_data(
             dataset_id=self.training_config.data_config.dataset_id,
             tokenizer=self._tokenizer,
@@ -452,6 +461,7 @@ class LoraFinetuningSingleDevice:
         """
         The core training loop.
         """
+        assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
         running_loss: float = 0.0
```
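The asserts added to the recipe exist because the inline recipe still needs these configs even though the pydantic model no longer requires them; the `isinstance` checks fail fast on a missing config and also narrow the `Optional[...]` types before the attributes are dereferenced. A small standalone sketch of the pattern, using trimmed-down stand-ins rather than the real classes:

```python
from typing import Optional
from pydantic import BaseModel

class DataConfig(BaseModel):      # stand-in for the real DataConfig
    dataset_id: str

class TrainingConfig(BaseModel):  # trimmed-down stand-in
    n_epochs: int
    data_config: Optional[DataConfig] = None

def setup(cfg: TrainingConfig) -> str:
    # Fail fast if the caller relied on provider-side defaults that an inline
    # recipe cannot provide; also narrows Optional[DataConfig] so the
    # cfg.data_config.dataset_id access below type-checks.
    assert isinstance(cfg.data_config, DataConfig), "DataConfig must be initialized"
    return cfg.data_config.dataset_id
```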