diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index db5c57821..d7801ba1c 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -14470,28 +14470,31 @@
"DPOAlignmentConfig": {
"type": "object",
"properties": {
- "reward_scale": {
+ "beta": {
"type": "number"
},
- "reward_clip": {
- "type": "number"
- },
- "epsilon": {
- "type": "number"
- },
- "gamma": {
- "type": "number"
+ "loss_type": {
+ "$ref": "#/components/schemas/DPOLossType",
+ "default": "sigmoid"
}
},
"additionalProperties": false,
"required": [
- "reward_scale",
- "reward_clip",
- "epsilon",
- "gamma"
+ "beta",
+ "loss_type"
],
"title": "DPOAlignmentConfig"
},
+ "DPOLossType": {
+ "type": "string",
+ "enum": [
+ "sigmoid",
+ "hinge",
+ "ipo",
+ "kto_pair"
+ ],
+ "title": "DPOLossType"
+ },
"DataConfig": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 29ba9dede..be02e1e42 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -10111,21 +10111,24 @@ components:
DPOAlignmentConfig:
type: object
properties:
- reward_scale:
- type: number
- reward_clip:
- type: number
- epsilon:
- type: number
- gamma:
+ beta:
type: number
+ loss_type:
+ $ref: '#/components/schemas/DPOLossType'
+ default: sigmoid
additionalProperties: false
required:
- - reward_scale
- - reward_clip
- - epsilon
- - gamma
+ - beta
+ - loss_type
title: DPOAlignmentConfig
+ DPOLossType:
+ type: string
+ enum:
+ - sigmoid
+ - hinge
+ - ipo
+ - kto_pair
+ title: DPOLossType
DataConfig:
type: object
properties:
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index b196c8a17..f6860ea4b 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -104,12 +104,18 @@ class RLHFAlgorithm(Enum):
dpo = "dpo"
+@json_schema_type
+class DPOLossType(Enum):
+ sigmoid = "sigmoid"
+ hinge = "hinge"
+ ipo = "ipo"
+ kto_pair = "kto_pair"
+
+
@json_schema_type
class DPOAlignmentConfig(BaseModel):
- reward_scale: float
- reward_clip: float
- epsilon: float
- gamma: float
+ beta: float
+ loss_type: DPOLossType = DPOLossType.sigmoid
@json_schema_type