Compare commits


5 commits

Author       SHA1         Message      Date
Botao Chen   357141f6de   refine       2025-03-09 18:14:26 -07:00
Botao Chen   41f86b5ce6   refine       2025-03-09 17:30:23 -07:00
Botao Chen   51282456b9   refine       2025-03-09 17:17:18 -07:00
Botao Chen   0b92fef3ba   refine       2025-03-03 23:53:27 -08:00
Botao Chen   840fd353f7   init commit  2025-03-03 23:41:11 -08:00


@@ -16,81 +16,49 @@ from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
+class TrainingStrategy(BaseModel):
+    # params that control the optimizer
+    learning_rate: Optional[Union[float, Literal["auto"]]] = "auto"
+    weight_decay: Optional[float] = 0.1
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # params that control how data is fed for training
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
+    shuffle: Optional[bool] = True
+    n_epochs: Optional[int] = 3

-@json_schema_type
-class DatasetFormat(Enum):
-    instruct = "instruct"
-    dialog = "dialog"
+    # training loop control params
+    max_training_steps: Optional[int] = None
+    max_validation_steps: Optional[int] = None
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"

-@json_schema_type
-class DataConfig(BaseModel):
-    dataset_id: str
-    batch_size: int
-    shuffle: bool
-    data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
-
-@json_schema_type
-class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    weight_decay: float
-    num_warmup_steps: int
-
-@json_schema_type
-class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
-
-@json_schema_type
-class TrainingConfig(BaseModel):
-    n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
-    efficiency_config: Optional[EfficiencyConfig] = None

     # precision for training
     dtype: Optional[str] = "bf16"


 @json_schema_type
-class LoraFinetuningConfig(BaseModel):
+class LoraFinetuningStrategy(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
+    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
+    apply_lora_to_mlp: Optional[bool] = True
+    apply_lora_to_output: Optional[bool] = False
+    rank: Optional[int] = 8
+    alpha: Optional[int] = 16
     use_dora: Optional[bool] = False
     quantize_base: Optional[bool] = False


 @json_schema_type
-class QATFinetuningConfig(BaseModel):
+class QATFinetuningStrategy(BaseModel):
     type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int


-AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
-    name="AlgorithmConfig",
+AlgorithmStrategy = register_schema(
+    Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
+    name="AlgorithmStrategy",
 )
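
To make the new surface concrete, below is a minimal usage sketch. The import path is an assumption (the diff does not show the module location); the class names, fields, and defaults all come from the schemas above.

# Hypothetical import path for the models defined in this diff.
from llama_stack.apis.post_training import (
    LoraFinetuningStrategy,
    TrainingStrategy,
)

# Every field defaults to "auto" or a sensible value, so an empty
# TrainingStrategy() is a complete, valid configuration.
strategy = TrainingStrategy()
assert strategy.learning_rate == "auto"
assert strategy.batch_size == "auto"

# Override only the knobs you care about; the rest stays on "auto".
tuned = TrainingStrategy(learning_rate=1e-5, n_epochs=1)

# LoRA with rank/alpha overridden; attention modules keep their defaults.
lora = LoraFinetuningStrategy(rank=16, alpha=32)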
@@ -108,35 +76,13 @@ class RLHFAlgorithm(Enum):
 @json_schema_type
-class DPOAlignmentConfig(BaseModel):
+class DPOAlignmentStrategy(BaseModel):
     reward_scale: float
     reward_clip: float
     epsilon: float
     gamma: float


-@json_schema_type
-class PostTrainingRLHFRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-
-    finetuned_model: URL
-
-    dataset_id: str
-    validation_dataset_id: str
-    algorithm: RLHFAlgorithm
-    algorithm_config: DPOAlignmentConfig
-    optimizer_config: OptimizerConfig
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-

 class PostTrainingJob(BaseModel):
     job_uuid: str
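
Unlike TrainingStrategy, DPOAlignmentStrategy keeps all four fields required, so a caller must supply every value explicitly. A hypothetical instantiation (values are illustrative only, not recommendations):

# Illustrative values; the schema defines no defaults.
dpo = DPOAlignmentStrategy(
    reward_scale=1.0,
    reward_clip=5.0,
    epsilon=0.1,
    gamma=0.99,
)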
@@ -176,26 +122,32 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
         model: str = Field(
             default="Llama3.2-3B-Instruct",
             description="Model descriptor from `llama model list`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")
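
With the defaults above, the simplified entry point reduces to two required arguments. A hypothetical client call, assuming a concrete PostTraining implementation is in hand (names and IDs below are placeholders):

# Sketch of a caller; `client` stands in for any PostTraining implementation.
async def launch_sft(client: PostTraining) -> None:
    job = await client.supervised_fine_tune(
        job_uuid="job-123",                    # hypothetical job ID
        training_dataset_id="my-sft-dataset",  # hypothetical dataset ID
        # model, validation_dataset_id, training_strategy, and algorithm
        # all fall back to defaults: Llama3.2-3B-Instruct + LoRA on "auto".
    )
    print(job.job_uuid)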