diff --git a/tests/client-sdk/post_training/test_supervied_fine_tuning.py b/tests/client-sdk/post_training/test_supervied_fine_tuning.py new file mode 100644 index 000000000..232510478 --- /dev/null +++ b/tests/client-sdk/post_training/test_supervied_fine_tuning.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"] + + +@pytest.mark.integration +@pytest.fixture(scope="session") +def post_training_provider_available(llama_stack_client): + providers = llama_stack_client.providers.list() + post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES] + return len(post_training_providers) > 0 + + +@pytest.mark.integration +def test_post_training_provider_registration(llama_stack_client, post_training_provider_available): + """Check if post_training is in the api list. + This is a sanity check to ensure the provider is registered.""" + if not post_training_provider_available: + pytest.skip("post training provider not available") + + providers = llama_stack_client.providers.list() + post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES] + assert len(post_training_providers) > 0 + + +@pytest.mark.integration +def test_get_training_jobs(llama_stack_client, post_training_provider_available): + """Test listing all training jobs.""" + if not post_training_provider_available: + pytest.skip("post training provider not available") + + jobs = llama_stack_client.post_training.get_training_jobs() + assert isinstance(jobs, dict) + assert "data" in jobs + assert isinstance(jobs["data"], list) + + +@pytest.mark.integration +def test_get_training_job_status(llama_stack_client, post_training_provider_available): + """Test getting status of a specific training job.""" + if not post_training_provider_available: + pytest.skip("post training provider not available") + + jobs = llama_stack_client.post_training.get_training_jobs() + if not jobs["data"]: + pytest.skip("No training jobs available to check status") + + job_uuid = jobs["data"][0]["job_uuid"] + job_status = llama_stack_client.post_training.get_training_job_status(job_uuid=job_uuid) + + assert job_status is not None + assert "job_uuid" in job_status + assert "status" in job_status + assert job_status["job_uuid"] == job_uuid