mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: Handle case when Customizer Job status is unknown (#1965)
# What does this PR do?

This PR handles the case where a Customization Job's status is `unknown`. Previously, `unknown` was not mapped to a valid `JobStatus`, so the PostTraining provider threw an exception when fetching or listing a job with that status. `unknown` is now mapped to `JobStatus.scheduled`.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]

`./scripts/unit-tests.sh tests/unit/providers/nvidia/test_supervised_fine_tuning.py` succeeds

[//]: # (## Documentation)

Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
This commit is contained in:
parent
6f97f9a593
commit
45e08ff417
2 changed files with 44 additions and 30 deletions
|
@ -27,11 +27,12 @@ from .models import _MODEL_ENTRIES
|
||||||
|
|
||||||
# Map API status to JobStatus enum
|
# Map API status to JobStatus enum
|
||||||
STATUS_MAPPING = {
|
STATUS_MAPPING = {
|
||||||
"running": "in_progress",
|
"running": JobStatus.in_progress.value,
|
||||||
"completed": "completed",
|
"completed": JobStatus.completed.value,
|
||||||
"failed": "failed",
|
"failed": JobStatus.failed.value,
|
||||||
"cancelled": "cancelled",
|
"cancelled": JobStatus.cancelled.value,
|
||||||
"pending": "scheduled",
|
"pending": JobStatus.scheduled.value,
|
||||||
|
"unknown": JobStatus.scheduled.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -200,10 +200,21 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_training_job_status(self):
|
def test_get_training_job_status(self):
|
||||||
|
customizer_status_to_job_status = [
|
||||||
|
("running", "in_progress"),
|
||||||
|
("completed", "completed"),
|
||||||
|
("failed", "failed"),
|
||||||
|
("cancelled", "cancelled"),
|
||||||
|
("pending", "scheduled"),
|
||||||
|
("unknown", "scheduled"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||||
|
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
|
||||||
self.mock_make_request.return_value = {
|
self.mock_make_request.return_value = {
|
||||||
"created_at": "2024-12-09T04:06:28.580220",
|
"created_at": "2024-12-09T04:06:28.580220",
|
||||||
"updated_at": "2024-12-09T04:21:19.852832",
|
"updated_at": "2024-12-09T04:21:19.852832",
|
||||||
"status": "completed",
|
"status": customizer_status,
|
||||||
"steps_completed": 1210,
|
"steps_completed": 1210,
|
||||||
"epochs_completed": 2,
|
"epochs_completed": 2,
|
||||||
"percentage_done": 100.0,
|
"percentage_done": 100.0,
|
||||||
|
@ -217,7 +228,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||||
|
|
||||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||||
assert status.status.value == "completed"
|
assert status.status.value == expected_status
|
||||||
assert status.steps_completed == 1210
|
assert status.steps_completed == 1210
|
||||||
assert status.epochs_completed == 2
|
assert status.epochs_completed == 2
|
||||||
assert status.percentage_done == 100.0
|
assert status.percentage_done == 100.0
|
||||||
|
@ -225,9 +236,11 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
assert status.train_loss == 1.718016266822815
|
assert status.train_loss == 1.718016266822815
|
||||||
assert status.val_loss == 1.8661999702453613
|
assert status.val_loss == 1.8661999702453613
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
|
||||||
self._assert_request(
|
self._assert_request(
|
||||||
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
|
self.mock_make_request,
|
||||||
|
"GET",
|
||||||
|
f"/v1/customization/jobs/{job_id}/status",
|
||||||
|
expected_params={"job_id": job_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_training_jobs(self):
|
def test_get_training_jobs(self):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue