diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index ce6448951..f6860ea4b 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -104,10 +104,18 @@ class RLHFAlgorithm(Enum):
     dpo = "dpo"
 
 
+@json_schema_type
+class DPOLossType(Enum):
+    sigmoid = "sigmoid"
+    hinge = "hinge"
+    ipo = "ipo"
+    kto_pair = "kto_pair"
+
+
 @json_schema_type
 class DPOAlignmentConfig(BaseModel):
     beta: float
-    loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+    loss_type: DPOLossType = DPOLossType.sigmoid
 
 
 @json_schema_type