forked from phoenix-oss/llama-stack-mirror
[2/n][torchtune integration] implement job management and return training artifacts (#593)
### Context In this PR, we - Implement the post training job management and get training artifacts apis - get_training_jobs - get_training_job_status - get_training_job_artifacts - get_training_job_logstream is deleted since the trace can be directly accessed by UI with Jaeger https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces - Refactor the post training and training types definition to make them more intuitive. - Rewrite the checkpointer to make it compatible with llama-stack file system and can be recognized during inference ### Test Unit test `pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings` <img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM" src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a"> e2e test with client side call <img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM" src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
This commit is contained in:
parent
5764a95912
commit
c294a01c4b
8 changed files with 331 additions and 67 deletions
|
@ -19,6 +19,7 @@ class TestPostTraining:
|
|||
@pytest.mark.asyncio
|
||||
async def test_supervised_fine_tune(self, post_training_stack):
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=False,
|
||||
|
@ -59,3 +60,33 @@ class TestPostTraining:
|
|||
)
|
||||
assert isinstance(response, PostTrainingJob)
|
||||
assert response.job_uuid == "1234"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_jobs(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
jobs_list = await post_training_impl.get_training_jobs()
|
||||
assert isinstance(jobs_list, List)
|
||||
assert jobs_list[0].job_uuid == "1234"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_job_status(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
job_status = await post_training_impl.get_training_job_status("1234")
|
||||
assert isinstance(job_status, PostTrainingJobStatusResponse)
|
||||
assert job_status.job_uuid == "1234"
|
||||
assert job_status.status == JobStatus.completed
|
||||
assert isinstance(job_status.checkpoints[0], Checkpoint)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_job_artifacts(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
||||
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
|
||||
assert job_artifacts.job_uuid == "1234"
|
||||
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
|
||||
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
|
||||
assert job_artifacts.checkpoints[0].epoch == 0
|
||||
assert (
|
||||
"/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
|
||||
in job_artifacts.checkpoints[0].path
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue