Compare commits

5 commits

| Author     | SHA1       | Message     | Date                       |
|------------|------------|-------------|----------------------------|
| Botao Chen | 357141f6de | refine      | 2025-03-09 18:14:26 -07:00 |
| Botao Chen | 41f86b5ce6 | refine      | 2025-03-09 17:30:23 -07:00 |
| Botao Chen | 51282456b9 | refine      | 2025-03-09 17:17:18 -07:00 |
| Botao Chen | 0b92fef3ba | refine      | 2025-03-03 23:53:27 -08:00 |
| Botao Chen | 840fd353f7 | init commit | 2025-03-03 23:41:11 -08:00 |


```diff
@@ -16,81 +16,49 @@ from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
-
-
-@json_schema_type
-class DatasetFormat(Enum):
-    instruct = "instruct"
-    dialog = "dialog"
-
-
-@json_schema_type
-class DataConfig(BaseModel):
-    dataset_id: str
-    batch_size: int
-    shuffle: bool
-    data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
-
-
-@json_schema_type
-class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    weight_decay: float
-    num_warmup_steps: int
-
-
-@json_schema_type
-class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
-
-
-@json_schema_type
-class TrainingConfig(BaseModel):
-    n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
-    efficiency_config: Optional[EfficiencyConfig] = None
+class TrainingStrategy(BaseModel):
+    # params that control the optimizer
+    learning_rate: Optional[Union[float, Literal["auto"]]] = "auto"
+    weight_decay: Optional[float] = 0.1
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # params that control how data is fed for training
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
+    shuffle: Optional[bool] = True
+    n_epochs: Optional[int] = 3
+
+    # training loop control params
+    max_training_steps: Optional[int] = None
+    max_validation_steps: Optional[int] = None
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # precision for training
     dtype: Optional[str] = "bf16"


 @json_schema_type
-class LoraFinetuningConfig(BaseModel):
+class LoraFinetuningStrategy(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
+    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
+    apply_lora_to_mlp: Optional[bool] = True
+    apply_lora_to_output: Optional[bool] = False
+    rank: Optional[int] = 8
+    alpha: Optional[int] = 16
     use_dora: Optional[bool] = False
     quantize_base: Optional[bool] = False


 @json_schema_type
-class QATFinetuningConfig(BaseModel):
+class QATFinetuningStrategy(BaseModel):
     type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int


-AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
-    name="AlgorithmConfig",
+AlgorithmStrategy = register_schema(
+    Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
+    name="AlgorithmStrategy",
 )
```
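To make the new `TrainingStrategy` concrete: every field carries a default, with `"auto"` marking values the training backend may choose. A minimal sketch follows; the import path is an assumption, not something the diff shows.

```python
# Hypothetical import path for the models defined in the diff above.
from llama_stack.apis.post_training import TrainingStrategy

# An empty constructor is already a complete configuration: anything not
# set stays "auto" (backend-chosen) or at its documented default.
auto = TrainingStrategy()
assert auto.learning_rate == "auto" and auto.batch_size == "auto"

# Callers override only the knobs that matter; the rest keep defaults.
custom = TrainingStrategy(learning_rate=1e-5, n_epochs=1, batch_size=8)
assert custom.gradient_accumulation_steps == "auto"
```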
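The `AlgorithmStrategy` union dispatches on the `type` field, a pydantic tagged-union discriminator, so a serialized payload validates to the right concrete class. A sketch assuming `register_schema` returns the annotated union unchanged (as the assignment above suggests), with purely illustrative payload values:

```python
from pydantic import TypeAdapter

# Hypothetical import path, as above.
from llama_stack.apis.post_training import (
    AlgorithmStrategy,
    QATFinetuningStrategy,
)

# "type" selects the branch of the union during validation.
payload = {"type": "QAT", "quantizer_name": "int4", "group_size": 32}
strategy = TypeAdapter(AlgorithmStrategy).validate_python(payload)
assert isinstance(strategy, QATFinetuningStrategy)
```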
```diff
@@ -108,35 +76,13 @@ class RLHFAlgorithm(Enum):
 @json_schema_type
-class DPOAlignmentConfig(BaseModel):
+class DPOAlignmentStrategy(BaseModel):
     reward_scale: float
     reward_clip: float
     epsilon: float
     gamma: float


-@json_schema_type
-class PostTrainingRLHFRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-
-    finetuned_model: URL
-
-    dataset_id: str
-    validation_dataset_id: str
-    algorithm: RLHFAlgorithm
-    algorithm_config: DPOAlignmentConfig
-    optimizer_config: OptimizerConfig
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class PostTrainingJob(BaseModel):
     job_uuid: str
```
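Unlike `TrainingStrategy`, the renamed `DPOAlignmentStrategy` keeps every field required, so callers must spell out the DPO hyperparameters. A sketch with illustrative numbers, not recommended settings:

```python
# Hypothetical import path, as above.
from llama_stack.apis.post_training import DPOAlignmentStrategy

# All four fields are required (no "auto" defaults).
dpo = DPOAlignmentStrategy(reward_scale=1.0, reward_clip=5.0, epsilon=0.1, gamma=0.99)
```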
```diff
@@ -176,26 +122,32 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
         model: str = Field(
             default="Llama3.2-3B-Instruct",
             description="Model descriptor from `llama model list`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")
```
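Taken together, a caller of the reworked protocol passes dataset ids and strategy objects directly instead of a nested `TrainingConfig`. A hedged sketch: `client` stands in for any implementation of the `PostTraining` protocol, and the ids are made up for illustration.

```python
# Hypothetical import path, as above.
from llama_stack.apis.post_training import (
    LoraFinetuningStrategy,
    TrainingStrategy,
)

async def run_jobs(client) -> None:
    # `client` is assumed to implement the PostTraining protocol.
    sft_job = await client.supervised_fine_tune(
        job_uuid="sft-job-001",
        training_dataset_id="my-sft-train",
        model="Llama3.2-3B-Instruct",
        validation_dataset_id="my-sft-val",
        training_strategy=TrainingStrategy(n_epochs=1),
        algorithm=LoraFinetuningStrategy(rank=8, alpha=16),
    )

    # preference_optimize now takes the same parameter shape.
    dpo_job = await client.preference_optimize(
        job_uuid="dpo-job-001",
        training_dataset_id="my-pref-train",
        model="Llama3.2-3B-Instruct",
        training_strategy=TrainingStrategy(learning_rate=5e-7),
    )
    print(sft_job.job_uuid, dpo_job.job_uuid)
```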