even more fixes

This commit is contained in:
Ashwin Bharambe 2025-08-05 14:52:09 -07:00
parent 0b6a7abb28
commit 882176928f
16 changed files with 28 additions and 18 deletions

View file

@ -154,6 +154,7 @@ providers:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -154,6 +154,7 @@ providers:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/starter/dpo_output
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import Any, Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel
class HuggingFacePostTrainingConfig(BaseModel):
@ -72,8 +71,13 @@ class HuggingFacePostTrainingConfig(BaseModel):
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
dpo_output_dir: str = Field(default_factory=lambda: tempfile.mkdtemp(prefix="dpo_output_"))
dpo_output_dir: str
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
return {
"checkpoint_format": "huggingface",
"distributed_backend": None,
"device": "cpu",
"dpo_output_dir": __distro_dir__ + "/dpo_output",
}