diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 74243afc6..fdbaa364d 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -66,6 +66,7 @@ class TrainingConfig(BaseModel):
 
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
+    type: Literal["LoRA"] = "LoRA"
     lora_attn_modules: List[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
@@ -77,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
 
 @json_schema_type
 class QATFinetuningConfig(BaseModel):
+    type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 
 
 AlgorithmConfig = Annotated[
-    Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
 ]
 
 
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 4306752e1..9b1269f16 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -39,8 +39,9 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        if job_uuid in self.jobs_list:
-            raise ValueError(f"Job {job_uuid} already exists")
+        for job in self.jobs_list:
+            if job_uuid == job.job_uuid:
+                raise ValueError(f"Job {job_uuid} already exists")
 
         post_training_job = PostTrainingJob(job_uuid=job_uuid)
 
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py
index 6959f9f9c..4ecc05187 100644
--- a/llama_stack/providers/tests/post_training/test_post_training.py
+++ b/llama_stack/providers/tests/post_training/test_post_training.py
@@ -19,6 +19,7 @@ class TestPostTraining:
     @pytest.mark.asyncio
     async def test_supervised_fine_tune(self, post_training_stack):
         algorithm_config = LoraFinetuningConfig(
+            type="LoRA",
             lora_attn_modules=["q_proj", "v_proj", "output_proj"],
             apply_lora_to_mlp=True,
             apply_lora_to_output=False,
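
Note (not part of the patch): a minimal sketch of why the `type` discriminators matter, assuming Pydantic v2's `TypeAdapter`. The classes are trimmed to the fields visible in the hunks above, and the example values ("int4", 32) are illustrative, not taken from the repo.

    from typing import Annotated, List, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter


    class LoraFinetuningConfig(BaseModel):
        # Trimmed stand-in for the real class in post_training.py.
        type: Literal["LoRA"] = "LoRA"
        lora_attn_modules: List[str]
        apply_lora_to_mlp: bool
        apply_lora_to_output: bool


    class QATFinetuningConfig(BaseModel):
        type: Literal["QAT"] = "QAT"
        quantizer_name: str
        group_size: int


    # Tagged union: Pydantic dispatches on the `type` field. Before the fix,
    # the Union listed LoraFinetuningConfig twice, so a "QAT" payload had no
    # branch to resolve to.
    AlgorithmConfig = Annotated[
        Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
    ]

    qat = TypeAdapter(AlgorithmConfig).validate_python(
        {"type": "QAT", "quantizer_name": "int4", "group_size": 32}
    )
    assert isinstance(qat, QATFinetuningConfig)

The torchtune hunk follows the same spirit of correctness: self.jobs_list holds PostTrainingJob objects rather than raw UUID strings, so the old `job_uuid in self.jobs_list` membership test never matched; the loop compares against each job's job_uuid instead.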