diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 1a588b025..1c03cc1b2 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -125,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel): validation_dataset_id: str algorithm: RLHFAlgorithm - algorithm_config: AlgorithmConfig + algorithm_config: DPOAlignmentConfig optimizer_config: OptimizerConfig training_config: TrainingConfig @@ -202,7 +202,7 @@ class PostTraining(Protocol): self, job_uuid: str, finetuned_model: str, - algorithm_config: AlgorithmConfig, + algorithm_config: DPOAlignmentConfig, training_config: TrainingConfig, hyperparam_search_config: dict[str, Any], logger_config: dict[str, Any],