Botao Chen 2024-12-13 14:55:01 -08:00
parent d55a8343ea
commit d0a72cc288
3 changed files with 7 additions and 3 deletions


@@ -66,6 +66,7 @@ class TrainingConfig(BaseModel):
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
+    type: Literal["LoRA"] = "LoRA"
     lora_attn_modules: List[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
@@ -77,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
 @json_schema_type
 class QATFinetuningConfig(BaseModel):
+    type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 AlgorithmConfig = Annotated[
-    Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
 ]
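For context on this hunk: the previous annotation listed LoraFinetuningConfig twice and never included QATFinetuningConfig, and a pydantic discriminated union also requires each member to carry the discriminator field, which is what the new type literals provide. A minimal sketch of how the corrected union resolves, assuming pydantic v2's TypeAdapter; the models are trimmed to the fields visible in this hunk, and the QAT values below are illustrative only:

from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: List[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]

# The "type" value routes validation to the matching member of the union.
# The quantizer_name and group_size values here are made up for illustration.
config = TypeAdapter(AlgorithmConfig).validate_python(
    {"type": "QAT", "quantizer_name": "int8", "group_size": 32}
)
assert isinstance(config, QATFinetuningConfig)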


@@ -39,8 +39,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        if job_uuid in self.jobs_list:
-            raise ValueError(f"Job {job_uuid} already exists")
+        for job in self.jobs_list:
+            if job_uuid == job.job_uuid:
+                raise ValueError(f"Job {job_uuid} already exists")
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
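The old membership test never fired because self.jobs_list appears to hold PostTrainingJob objects rather than raw UUID strings, so `job_uuid in self.jobs_list` compares a str against job objects. A minimal sketch of the difference, with PostTrainingJob trimmed to the one field used here and jobs_list assumed to be a plain list:

from pydantic import BaseModel


class PostTrainingJob(BaseModel):
    job_uuid: str


jobs_list = [PostTrainingJob(job_uuid="job-1")]
job_uuid = "job-1"

# Old check: a str is compared against PostTrainingJob objects, so the
# duplicate is not detected.
assert (job_uuid in jobs_list) is False

# New check: compare against each job's job_uuid, which catches the duplicate.
assert any(job_uuid == job.job_uuid for job in jobs_list)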


@@ -19,6 +19,7 @@ class TestPostTraining:
     @pytest.mark.asyncio
     async def test_supervised_fine_tune(self, post_training_stack):
         algorithm_config = LoraFinetuningConfig(
+            type="LoRA",
            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
            apply_lora_to_mlp=True,
            apply_lora_to_output=False,
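Since the new type field defaults to "LoRA", passing it explicitly in the test is optional, but the value is always present once the model is serialized, which is what the discriminated AlgorithmConfig union keys on. A small sketch using the trimmed LoraFinetuningConfig from the earlier sketch (the real model has additional fields not shown in this hunk), assuming pydantic v2:

# "type" is filled in by its default even when not passed explicitly.
config = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
)
assert config.model_dump()["type"] == "LoRA"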