fix: Add missing DPO discriminator for remote provider serialization

This commit is contained in:
Nehanth 2025-06-25 15:52:04 -04:00
parent fa0b0c13d4
commit 40a0cef38e

View file

@ -87,7 +87,16 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
@json_schema_type
class DPOAlignmentConfig(BaseModel):
type: Literal["DPO"] = "DPO"
reward_scale: float
reward_clip: float
epsilon: float
gamma: float
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig | DPOAlignmentConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@ -104,14 +113,6 @@ class RLHFAlgorithm(Enum):
dpo = "dpo"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
@ -124,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel):
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
algorithm_config: AlgorithmConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
@ -201,7 +202,7 @@ class PostTraining(Protocol):
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
algorithm_config: AlgorithmConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],