mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 03:23:53 +00:00
feat(api): define a more coherent jobs api across different flows
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
71ed47ea76
commit
0f50cfa561
15 changed files with 1864 additions and 1670 deletions
|
|
@ -17,12 +17,14 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
|||
TrainingConfigOptimizerConfig,
|
||||
)
|
||||
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.post_training import (
|
||||
ListPostTrainingJobsResponse,
|
||||
PostTrainingJob,
|
||||
)
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
ListNvidiaPostTrainingJobs,
|
||||
NvidiaPostTrainingAdapter,
|
||||
NvidiaPostTrainingConfig,
|
||||
NvidiaPostTrainingJob,
|
||||
NvidiaPostTrainingJobStatusResponse,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -49,21 +51,25 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
found = False
|
||||
for call_args in mock_call.call_args_list:
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
if call_args[0] == (expected_method, expected_path):
|
||||
found = True
|
||||
else:
|
||||
if call_args[1]["method"] == expected_method and call_args[1]["path"] == expected_path:
|
||||
found = True
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
if expected_params:
|
||||
if call_args[1]["params"] == expected_params:
|
||||
found = True
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if call_args[1]["json"][key] == value:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
def test_supervised_fine_tune(self):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
|
|
@ -151,9 +157,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
assert isinstance(training_job, PostTrainingJob)
|
||||
assert training_job.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
|
|
@ -199,38 +204,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def test_get_training_job_status(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": "completed",
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == "completed"
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
|
||||
)
|
||||
|
||||
def test_get_training_jobs(self):
|
||||
def test_list_post_training_jobs(self):
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
self.mock_make_request.return_value = {
|
||||
"data": [
|
||||
|
|
@ -258,12 +232,12 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
]
|
||||
}
|
||||
|
||||
jobs = self.run_async(self.adapter.get_training_jobs())
|
||||
jobs = self.run_async(self.adapter.list_post_training_jobs())
|
||||
|
||||
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
||||
assert isinstance(jobs, ListPostTrainingJobsResponse)
|
||||
assert len(jobs.data) == 1
|
||||
job = jobs.data[0]
|
||||
assert job.job_uuid == job_id
|
||||
assert job.id == job_id
|
||||
assert job.status.value == "completed"
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
|
|
@ -275,14 +249,36 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
)
|
||||
|
||||
def test_cancel_training_job(self):
|
||||
self.mock_make_request.return_value = {} # Empty response for successful cancellation
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
self.mock_make_request.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": job_id,
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"config": {
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
"dataset": {"name": "default/sample-basic-test"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "completed",
|
||||
"project": "default",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
|
||||
result = self.run_async(self.adapter.update_post_training_job(job_id=job_id, status=JobStatus.cancelled))
|
||||
assert result.id == job_id
|
||||
|
||||
assert result is None
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue