diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index 9f0d887a7..960acf775 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -107,7 +107,8 @@ class TestPostTraining: logger.info(f"Starting training job with UUID: {self.job_uuid}") # train with HF trl SFTTrainer as the default - os.makedirs("~/.llama/checkpoints/", exist_ok=True) + checkpoint_dir = os.path.expanduser("/mnt/") + # os.makedirs(checkpoint_dir, exist_ok=True) started = datetime.now(timezone.utc) _ = llama_stack_client.post_training.supervised_fine_tune( @@ -117,7 +118,7 @@ class TestPostTraining: training_config=training_config, hyperparam_search_config={}, logger_config={}, - checkpoint_dir="~/.llama/checkpoints/", + checkpoint_dir=checkpoint_dir, ) while True: