This commit is contained in:
Botao Chen 2025-03-09 17:17:18 -07:00
parent 0b92fef3ba
commit 51282456b9

View file

@ -16,43 +16,29 @@ from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type @json_schema_type
class DatasetFormat(Enum): class TrainingStrategy(BaseModel):
instruct = "instruct" # params that control Optimizer
dialog = "dialog"
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: Optional[int] = 1
shuffle: Optional[bool] = True
data_format: Optional[DatasetFormat] = DatasetFormat.instruct
validation_dataset_id: Optional[str] = None
train_on_input: Optional[bool] = False
@json_schema_type
class OptimizerConfig(BaseModel):
lr: Optional[float] = 2e-5 lr: Optional[float] = 2e-5
weight_decay: Optional[float] = 0.1 weight_decay: Optional[float] = 0.1
num_warmup_steps: Optional[int] = 20 num_warmup_steps: Optional[int] = 0
# paramas that control how data is fed for training
batch_size: Optional[int] = 1
shuffle: Optional[bool] = True
n_epochs: Optional[int] = 3
@json_schema_type # training loop control params
class TrainingConfig(BaseModel): max_training_steps: Optional[int] = None
data_config: DataConfig
optimizer_config: Optional[OptimizerConfig] = OptimizerConfig()
n_epochs: Optional[int] = 1
max_steps_per_epoch: Optional[int] = None
gradient_accumulation_steps: Optional[int] = 1
max_validation_steps: Optional[int] = None max_validation_steps: Optional[int] = None
gradient_accumulation_steps: Optional[int] = 1
# precision for training
dtype: Optional[str] = "bf16" dtype: Optional[str] = "bf16"
@json_schema_type @json_schema_type
class LoraFinetuningConfig(BaseModel): class LoraFinetuningStrategy(BaseModel):
type: Literal["LoRA"] = "LoRA" type: Literal["LoRA"] = "LoRA"
lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"] lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
apply_lora_to_mlp: Optional[bool] = True apply_lora_to_mlp: Optional[bool] = True
@ -64,15 +50,15 @@ class LoraFinetuningConfig(BaseModel):
@json_schema_type @json_schema_type
class QATFinetuningConfig(BaseModel): class QATFinetuningStrategy(BaseModel):
type: Literal["QAT"] = "QAT" type: Literal["QAT"] = "QAT"
quantizer_name: str quantizer_name: str
group_size: int group_size: int
AlgorithmConfig = register_schema( AlgorithmStrategy = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")], Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
name="AlgorithmConfig", name="AlgorithmStrategy",
) )
@ -90,35 +76,13 @@ class RLHFAlgorithm(Enum):
@json_schema_type @json_schema_type
class DPOAlignmentConfig(BaseModel): class DPOAlignmentStrategy(BaseModel):
reward_scale: float reward_scale: float
reward_clip: float reward_clip: float
epsilon: float epsilon: float
gamma: float gamma: float
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel): class PostTrainingJob(BaseModel):
job_uuid: str job_uuid: str
@ -158,24 +122,32 @@ class PostTraining(Protocol):
async def supervised_fine_tune( async def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,
training_config: TrainingConfig, training_dataset_id: str,
model: str = Field( model: str = Field(
default="Llama3.2-3B-Instruct", default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None, # Optional
validation_dataset_id: Optional[str] = None,
training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
althorighm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize( async def preference_optimize(
self, self,
job_uuid: str, job_uuid: str,
finetuned_model: str, training_dataset_id: str,
algorithm_config: DPOAlignmentConfig, model: str = Field(
training_config: TrainingConfig, default="Llama3.2-3B-Instruct",
hyperparam_search_config: Dict[str, Any], description="Model descriptor from `llama model list`",
logger_config: Dict[str, Any], ),
# Optional
validation_dataset_id: Optional[str] = None,
training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
althorighm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET") @webmethod(route="/post-training/jobs", method="GET")