Mirror of https://github.com/meta-llama/llama-stack.git
feat: make training config fields optional (#1861)
# What does this PR do?

Today, `supervised_fine_tune` itself and the `TrainingConfig` class have a number of required fields that a provider implementation may not need. For example, if a provider handles hyperparameters in its own configuration, along with dataset retrieval, optimizer settings, and LoRA configuration, users are still forced to pass a virtually empty `DataConfig`, `OptimizerConfig`, and `AlgorithmConfig` in some cases. Many of these fields are intended specifically for llama models, with knobs aimed at inline customization. Adding remote post_training providers requires loosening these arguments; the alternative is forcing users to pass empty objects just to satisfy the pydantic models. A sketch of the pain point follows below.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
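To make the pain point concrete, here is a minimal sketch using simplified stand-ins for the real models (the field names on `DataConfig` and `OptimizerConfig` here are illustrative, not the actual schema):

```python
from pydantic import BaseModel

# Simplified stand-ins: the real DataConfig and OptimizerConfig carry
# several more required fields, which makes the problem worse.
class DataConfig(BaseModel):
    dataset_id: str

class OptimizerConfig(BaseModel):
    lr: float

class TrainingConfig(BaseModel):
    n_epochs: int
    data_config: DataConfig            # required before this change
    optimizer_config: OptimizerConfig  # required before this change

# A provider that sources its dataset and optimizer settings from its own
# configuration still forces callers to fabricate placeholder objects:
config = TrainingConfig(
    n_epochs=1,
    data_config=DataConfig(dataset_id="unused"),
    optimizer_config=OptimizerConfig(lr=0.0),
)
```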
Parent: 70a7e4d51e
Commit: 0751a960a5
4 changed files with 29 additions and 21 deletions
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
 
 @json_schema_type
 class TrainingConfig(BaseModel):
     n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: Optional[int] = 1
+    data_config: Optional[DataConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
     efficiency_config: Optional[EfficiencyConfig] = None
     dtype: Optional[str] = "bf16"
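Revisiting the sketch from the PR description with the new defaults, a config that sets only `n_epochs` now validates (the stand-in config classes are empty here; the real ones carry more fields):

```python
from typing import Optional
from pydantic import BaseModel

# Empty stand-ins so the sketch is self-contained.
class DataConfig(BaseModel): ...
class OptimizerConfig(BaseModel): ...
class EfficiencyConfig(BaseModel): ...

class TrainingConfig(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int = 1
    gradient_accumulation_steps: int = 1
    max_validation_steps: Optional[int] = 1
    data_config: Optional[DataConfig] = None
    optimizer_config: Optional[OptimizerConfig] = None
    efficiency_config: Optional[EfficiencyConfig] = None
    dtype: Optional[str] = "bf16"

# Previously a ValidationError unless data_config and optimizer_config
# were supplied; now every knob beyond n_epochs is optional.
config = TrainingConfig(n_epochs=1)
print(config.data_config)  # None: the provider decides how to fill this in
```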
@@ -177,9 +177,9 @@ class PostTraining(Protocol):
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        model: str = Field(
-            default="Llama3.2-3B-Instruct",
-            description="Model descriptor from `llama model list`",
+        model: Optional[str] = Field(
+            default=None,
+            description="Model descriptor for training if not in provider config",
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
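On the protocol side, a caller can now omit `model`, `checkpoint_dir`, and `algorithm_config` and let a remote provider resolve them from its own configuration. A hedged usage sketch (`impl` stands for any provider implementing `PostTraining`, `TrainingConfig` is the class from the hunk above, and the `job_uuid` argument mirrors the rest of the post_training API; the identifier is illustrative):

```python
# Hypothetical call site after this change; only the required arguments
# are passed, everything else falls back to the new defaults.
async def start_job(impl) -> None:
    job = await impl.supervised_fine_tune(
        job_uuid="ft-job-001",                       # illustrative identifier
        training_config=TrainingConfig(n_epochs=1),  # minimal config, see above
        hyperparam_search_config={},
        logger_config={},
        # model, checkpoint_dir, and algorithm_config are omitted: a remote
        # provider can now resolve them from its own configuration instead of
        # requiring "Llama3.2-3B-Instruct" from every caller.
    )
    print(job)
```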