diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index b196c8a17..1a588b025 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -87,7 +87,16 @@ class QATFinetuningConfig(BaseModel):
     group_size: int
 
 
-AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
+@json_schema_type
+class DPOAlignmentConfig(BaseModel):
+    type: Literal["DPO"] = "DPO"
+    reward_scale: float
+    reward_clip: float
+    epsilon: float
+    gamma: float
+
+
+AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig | DPOAlignmentConfig, Field(discriminator="type")]
 register_schema(AlgorithmConfig, name="AlgorithmConfig")
 
 
@@ -104,14 +113,6 @@ class RLHFAlgorithm(Enum):
     dpo = "dpo"
 
 
-@json_schema_type
-class DPOAlignmentConfig(BaseModel):
-    reward_scale: float
-    reward_clip: float
-    epsilon: float
-    gamma: float
-
-
 @json_schema_type
 class PostTrainingRLHFRequest(BaseModel):
     """Request to finetune a model."""
@@ -124,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel):
     validation_dataset_id: str
 
     algorithm: RLHFAlgorithm
-    algorithm_config: DPOAlignmentConfig
+    algorithm_config: AlgorithmConfig
     optimizer_config: OptimizerConfig
     training_config: TrainingConfig
 
@@ -201,7 +202,7 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
+        algorithm_config: AlgorithmConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: dict[str, Any],
         logger_config: dict[str, Any],
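
Illustrative note (not part of the patch): by giving DPOAlignmentConfig a Literal "type" discriminator and adding it to the AlgorithmConfig union, an algorithm_config payload tagged "DPO" can be resolved to the DPO variant the same way the LoRA and QAT variants are. A minimal sketch of that behavior follows; LoraFinetuningConfig is a simplified stand-in here (its real fields are omitted) and the numeric values are arbitrary.

# Sketch only: demonstrates Pydantic's discriminated-union dispatch on the "type" field.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class LoraFinetuningConfig(BaseModel):
    # Stand-in for the real model; only the discriminator is shown.
    type: Literal["LoRA"] = "LoRA"


class DPOAlignmentConfig(BaseModel):
    type: Literal["DPO"] = "DPO"
    reward_scale: float
    reward_clip: float
    epsilon: float
    gamma: float


AlgorithmConfig = Annotated[
    LoraFinetuningConfig | DPOAlignmentConfig,
    Field(discriminator="type"),
]

# Pydantic selects the concrete model from the "type" field of the payload.
cfg = TypeAdapter(AlgorithmConfig).validate_python(
    {"type": "DPO", "reward_scale": 1.0, "reward_clip": 5.0, "epsilon": 1e-5, "gamma": 0.99}
)
assert isinstance(cfg, DPOAlignmentConfig)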