feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions

View file

@ -8,14 +8,12 @@ from typing import List
import pytest
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
OptimizerConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
@ -84,7 +82,6 @@ class TestPostTraining:
async def test_get_training_job_status(self, post_training_stack):
post_training_impl = post_training_stack
job_status = await post_training_impl.get_training_job_status("1234")
assert isinstance(job_status, PostTrainingJobStatusResponse)
assert job_status.job_uuid == "1234"
assert job_status.status == JobStatus.completed
assert isinstance(job_status.checkpoints[0], Checkpoint)
@ -93,7 +90,6 @@ class TestPostTraining:
async def test_get_training_job_artifacts(self, post_training_stack):
post_training_impl = post_training_stack
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
assert job_artifacts.job_uuid == "1234"
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"

View file

@ -17,12 +17,14 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfigOptimizerConfig,
)
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
ListPostTrainingJobsResponse,
PostTrainingJob,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
NvidiaPostTrainingJob,
NvidiaPostTrainingJobStatusResponse,
)
@ -49,21 +51,25 @@ class TestNvidiaPostTraining(unittest.TestCase):
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
found = False
for call_args in mock_call.call_args_list:
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
if call_args[0] == (expected_method, expected_path):
found = True
else:
if call_args[1]["method"] == expected_method and call_args[1]["path"] == expected_path:
found = True
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
if call_args[1]["params"] == expected_params:
found = True
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
if expected_json:
for key, value in expected_json.items():
if call_args[1]["json"][key] == value:
found = True
assert found
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
@ -151,9 +157,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
assert isinstance(training_job, PostTrainingJob)
assert training_job.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
@ -199,38 +204,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
)
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
def test_get_training_jobs(self):
def test_list_post_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
@ -258,12 +232,12 @@ class TestNvidiaPostTraining(unittest.TestCase):
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
jobs = self.run_async(self.adapter.list_post_training_jobs())
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert isinstance(jobs, ListPostTrainingJobsResponse)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.id == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
@ -275,14 +249,36 @@ class TestNvidiaPostTraining(unittest.TestCase):
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
result = self.run_async(self.adapter.update_post_training_job(job_id=job_id, status=JobStatus.cancelled))
assert result.id == job_id
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",