mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: Add missing DPO discriminator for remote provider serialization
This commit is contained in:
parent
fa0b0c13d4
commit
40a0cef38e
1 changed file with 12 additions and 11 deletions
|
@ -87,7 +87,16 @@ class QATFinetuningConfig(BaseModel):
|
||||||
group_size: int
|
group_size: int
|
||||||
|
|
||||||
|
|
||||||
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
@json_schema_type
|
||||||
|
class DPOAlignmentConfig(BaseModel):
|
||||||
|
type: Literal["DPO"] = "DPO"
|
||||||
|
reward_scale: float
|
||||||
|
reward_clip: float
|
||||||
|
epsilon: float
|
||||||
|
gamma: float
|
||||||
|
|
||||||
|
|
||||||
|
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig | DPOAlignmentConfig, Field(discriminator="type")]
|
||||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,14 +113,6 @@ class RLHFAlgorithm(Enum):
|
||||||
dpo = "dpo"
|
dpo = "dpo"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class DPOAlignmentConfig(BaseModel):
|
|
||||||
reward_scale: float
|
|
||||||
reward_clip: float
|
|
||||||
epsilon: float
|
|
||||||
gamma: float
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PostTrainingRLHFRequest(BaseModel):
|
class PostTrainingRLHFRequest(BaseModel):
|
||||||
"""Request to finetune a model."""
|
"""Request to finetune a model."""
|
||||||
|
@ -124,7 +125,7 @@ class PostTrainingRLHFRequest(BaseModel):
|
||||||
validation_dataset_id: str
|
validation_dataset_id: str
|
||||||
|
|
||||||
algorithm: RLHFAlgorithm
|
algorithm: RLHFAlgorithm
|
||||||
algorithm_config: DPOAlignmentConfig
|
algorithm_config: AlgorithmConfig
|
||||||
|
|
||||||
optimizer_config: OptimizerConfig
|
optimizer_config: OptimizerConfig
|
||||||
training_config: TrainingConfig
|
training_config: TrainingConfig
|
||||||
|
@ -201,7 +202,7 @@ class PostTraining(Protocol):
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
finetuned_model: str,
|
finetuned_model: str,
|
||||||
algorithm_config: DPOAlignmentConfig,
|
algorithm_config: AlgorithmConfig,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: dict[str, Any],
|
hyperparam_search_config: dict[str, Any],
|
||||||
logger_config: dict[str, Any],
|
logger_config: dict[str, Any],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue