This commit is contained in:
Botao Chen 2025-03-09 17:17:18 -07:00
parent 0b92fef3ba
commit 51282456b9

View file

@ -16,43 +16,29 @@ from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Enumeration of dataset formats; the two members shown are "instruct" and "dialog".
@json_schema_type
class DatasetFormat(Enum):
instruct = "instruct"
dialog = "dialog"
# Data-related settings for a training run.
# Only dataset_id is required; every other field is Optional and has a default.
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: Optional[int] = 1
shuffle: Optional[bool] = True
# defaults to the "instruct" member of DatasetFormat
data_format: Optional[DatasetFormat] = DatasetFormat.instruct
# None means no validation dataset was specified
validation_dataset_id: Optional[str] = None
# NOTE(review): presumably controls whether prompt/input tokens contribute to
# the loss — not established by this file, confirm with the trainer code
train_on_input: Optional[bool] = False
# NOTE(review): this span carries three class headers (OptimizerConfig,
# TrainingStrategy, TrainingConfig) with no indentation and no +/- markers, and
# several fields appear twice with different defaults (num_warmup_steps 20 vs 0,
# n_epochs 1 vs 3, gradient_accumulation_steps twice). It looks like the old and
# new sides of a rename/merge shown together — confirm against the repository
# which lines belong to which class before editing.
@json_schema_type
class OptimizerConfig(BaseModel):
class TrainingStrategy(BaseModel):
# params that control Optimizer
lr: Optional[float] = 2e-5
weight_decay: Optional[float] = 0.1
num_warmup_steps: Optional[int] = 20
@json_schema_type
class TrainingConfig(BaseModel):
data_config: DataConfig
# NOTE(review): the default is an OptimizerConfig() instance created once, when
# the class body is evaluated
optimizer_config: Optional[OptimizerConfig] = OptimizerConfig()
n_epochs: Optional[int] = 1
max_steps_per_epoch: Optional[int] = None
gradient_accumulation_steps: Optional[int] = 1
num_warmup_steps: Optional[int] = 0
# params that control how data is fed for training
batch_size: Optional[int] = 1
shuffle: Optional[bool] = True
n_epochs: Optional[int] = 3
# training loop control params
# None on either step limit means no explicit limit is set here
max_training_steps: Optional[int] = None
max_validation_steps: Optional[int] = None
gradient_accumulation_steps: Optional[int] = 1
# precision for training
# NOTE(review): dtype is a free-form string; "bf16" is the only value visible
# here, the set of accepted values is not defined in this span
dtype: Optional[str] = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
class LoraFinetuningStrategy(BaseModel):
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
apply_lora_to_mlp: Optional[bool] = True
@ -64,15 +50,15 @@ class LoraFinetuningConfig(BaseModel):
# NOTE(review): two class headers (QATFinetuningConfig, QATFinetuningStrategy)
# appear back to back with no +/- markers; this looks like the old and new name
# of the same class — confirm against the repository.
@json_schema_type
class QATFinetuningConfig(BaseModel):
class QATFinetuningStrategy(BaseModel):
# literal tag fixed to "QAT"; the algorithm union in this file is declared with
# Field(discriminator="type"), so this field selects the variant
type: Literal["QAT"] = "QAT"
# both remaining fields are required (no defaults)
quantizer_name: str
group_size: int
# Passes a union of the LoRA and QAT models, discriminated on their "type"
# field, to register_schema together with a schema name, and binds the result
# to a module-level alias.
# NOTE(review): two call openings (AlgorithmConfig, AlgorithmStrategy) share a
# single closing parenthesis here; this looks like the old and new version of
# one statement shown together — confirm which one is current.
AlgorithmConfig = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
name="AlgorithmConfig",
AlgorithmStrategy = register_schema(
Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
name="AlgorithmStrategy",
)
@ -90,35 +76,13 @@ class RLHFAlgorithm(Enum):
# NOTE(review): two class headers (DPOAlignmentConfig, DPOAlignmentStrategy)
# appear back to back with no +/- markers; this looks like the old and new name
# of the same class — confirm against the repository.
# All four fields are required floats; none has a default.
@json_schema_type
class DPOAlignmentConfig(BaseModel):
class DPOAlignmentStrategy(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float
# Request model bundling the job id, model URL, dataset ids, algorithm choice and
# configuration objects for an RLHF post-training job. Every field is required.
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
# typed with the *Config class names (DPOAlignmentConfig, OptimizerConfig,
# TrainingConfig) that also appear elsewhere in this file
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
# until then both are untyped string-keyed dictionaries
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@ -158,24 +122,32 @@ class PostTraining(Protocol):
# Protocol stub (body is `...`) for starting a supervised fine-tuning job; it is
# declared to return a PostTrainingJob.
# NOTE(review): the parameter list below mixes parameters typed with the
# *Config names (training_config, algorithm_config) and the *Strategy names
# (training_strategy, althorighm) with no +/- markers; it appears to show the
# old and new signature together — confirm which parameters are current.
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
training_dataset_id: str,
# the default is a Field(...) object carrying a default value and description
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
# Optional
validation_dataset_id: Optional[str] = None,
# NOTE(review): TrainingStrategy() and LoraFinetuningStrategy() are evaluated
# once, when the signature is defined, so the same instances serve as the
# default for every call
training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
# NOTE(review): "althorighm" looks like a misspelling of "algorithm"; it is a
# public keyword name, so renaming changes the API — confirm before fixing
althorighm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
) -> PostTrainingJob: ...
# Protocol stub (body is `...`) exposed via webmethod as
# POST /post-training/preference-optimize; it is declared to return a
# PostTrainingJob.
# NOTE(review): the parameter list below mixes parameters typed with the
# *Config names (algorithm_config, training_config) and the *Strategy names
# (training_strategy, althorighm) with no +/- markers; it appears to show the
# old and new signature together — confirm which parameters are current.
@webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
training_dataset_id: str,
# the default is a Field(...) object carrying a default value and description
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
),
# Optional
validation_dataset_id: Optional[str] = None,
# NOTE(review): TrainingStrategy() and LoraFinetuningStrategy() are evaluated
# once, when the signature is defined, so the same instances serve as the
# default for every call
training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
# NOTE(review): "althorighm" looks like a misspelling of "algorithm"; it is a
# public keyword name, so renaming changes the API — confirm before fixing
althorighm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")