Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2026-01-03 13:02:16 +00:00
parameter validation, test cases
This commit is contained in:
parent
d7340da7a6
commit
87ce96c1f7
4 changed files with 453 additions and 70 deletions
@@ -9,6 +9,8 @@ from typing import Any, Dict, Optional
 
 from pydantic import BaseModel, Field
 
+# TODO: add default values for all fields
+
 
 class NvidiaPostTrainingConfig(BaseModel):
     """Configuration for NVIDIA Post Training implementation."""
@@ -58,3 +60,54 @@ class NvidiaPostTrainingConfig(BaseModel):
             "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
             "customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
         }
+
+
+class SFTLoRADefaultConfig(BaseModel):
+    """NVIDIA-specific training configuration with default values."""
+
+    # ToDo: split into SFT and LoRA configs??
+
+    # General training parameters
+    n_epochs: int = 50
+
+    # NeMo customizer specific parameters
+    log_every_n_steps: Optional[int] = None
+    val_check_interval: float = 0.25
+    sequence_packing_enabled: bool = False
+    weight_decay: float = 0.01
+    lr: float = 0.0001
+
+    # SFT specific parameters
+    hidden_dropout: Optional[float] = None
+    attention_dropout: Optional[float] = None
+    ffn_dropout: Optional[float] = None
+
+    # LoRA default parameters
+    lora_adapter_dim: int = 8
+    lora_adapter_dropout: Optional[float] = None
+    lora_alpha: int = 16
+
+    # Data config
+    batch_size: int = 8
+
+    @classmethod
+    def sample_config(cls) -> Dict[str, Any]:
+        """Return a sample configuration for NVIDIA training."""
+        return {
+            "n_epochs": 50,
+            "log_every_n_steps": 10,
+            "val_check_interval": 0.25,
+            "sequence_packing_enabled": False,
+            "weight_decay": 0.01,
+            "hidden_dropout": 0.1,
+            "attention_dropout": 0.1,
+            "lora_adapter_dim": 8,
+            "lora_alpha": 16,
+            "data_config": {
+                "dataset_id": "default",
+                "batch_size": 8,
+            },
+            "optimizer_config": {
+                "lr": 0.0001,
+            },
+        }
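
For context, a short usage sketch (not part of the diff): SFTLoRADefaultConfig can be instantiated with per-field overrides, and sample_config() returns the nested dict shown above. The import path below is an assumption; adjust it to wherever this config module lives in the tree.

# Usage sketch, not from this commit. The import path is assumed.
from llama_stack.providers.remote.post_training.nvidia.config import (
    SFTLoRADefaultConfig,
)

# Unset fields fall back to the class defaults.
cfg = SFTLoRADefaultConfig(n_epochs=10, lora_adapter_dim=16)
assert cfg.lora_alpha == 16            # class default
assert cfg.lora_adapter_dropout is None

# Pydantic provides dict export; exclude_none drops the unset optionals.
print(cfg.model_dump(exclude_none=True))

# sample_config() mirrors the nested layout shown in the diff.
sample = SFTLoRADefaultConfig.sample_config()
assert sample["data_config"]["batch_size"] == 8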
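The commit title mentions parameter validation, but no validators appear in this hunk. Below is a hedged sketch of what range checks on these fields could look like with Pydantic; the class name, the bounds, and the power-of-two rule are illustrative assumptions, not values taken from the commit.

from typing import Optional
from pydantic import BaseModel, Field, field_validator

class SFTLoRAValidatedConfig(BaseModel):
    # Illustrative bounds only; the actual commit's constraints may differ.
    n_epochs: int = Field(default=50, gt=0)
    val_check_interval: float = Field(default=0.25, gt=0.0, le=1.0)
    weight_decay: float = Field(default=0.01, ge=0.0)
    lr: float = Field(default=0.0001, gt=0.0)
    lora_adapter_dim: int = Field(default=8, gt=0)
    lora_alpha: int = Field(default=16, gt=0)
    hidden_dropout: Optional[float] = Field(default=None, ge=0.0, le=1.0)

    @field_validator("lora_adapter_dim")
    @classmethod
    def _check_power_of_two(cls, v: int) -> int:
        # Hypothetical rule: adapter dims are typically powers of two.
        if v & (v - 1) != 0:
            raise ValueError("lora_adapter_dim should be a power of two")
        return v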
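The "${env.NVIDIA_PROJECT_ID:test-project}" values follow the ${env.VAR:default} placeholder convention used in llama-stack sample configs. A minimal, self-contained sketch of resolving such placeholders; the helper name and regex are made up for illustration, since llama-stack ships its own resolver.

import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")

def resolve_env_placeholders(value: str) -> str:
    """Replace ${env.VAR:default} tokens with values from os.environ.

    Hypothetical helper for illustration only.
    """
    def _sub(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        return os.environ.get(var, default if default is not None else "")
    return _ENV_PATTERN.sub(_sub, value)

# Falls back to the default when NVIDIA_CUSTOMIZER_URL is unset.
print(resolve_env_placeholders("${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}"))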