From 6a0c38f12369ab426f7abdc3c73b8802607f819a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 6 Mar 2025 19:34:02 +0000 Subject: [PATCH] add test cases for customizer --- llama_stack/templates/nvidia/run.yaml | 6 ++- tests/client-sdk/post_training/__init__.py | 5 +++ .../test_supervised_fine_tuning.py | 41 +++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 tests/client-sdk/post_training/__init__.py create mode 100644 tests/client-sdk/post_training/test_supervised_fine_tuning.py diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 8a7a40266..7559e518d 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -77,7 +77,11 @@ providers: post_training: - provider_id: nvidia-customizer provider_type: remote::nvidia - config: {} + config: + customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:} + user_id: ${env.NVIDIA_USER_ID:} + project_id: ${env.NVIDIA_PROJECT_ID:} + dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:} tool_runtime: - provider_id: rag-runtime provider_type: inline::rag-runtime diff --git a/tests/client-sdk/post_training/__init__.py b/tests/client-sdk/post_training/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/client-sdk/post_training/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/client-sdk/post_training/test_supervised_fine_tuning.py b/tests/client-sdk/post_training/test_supervised_fine_tuning.py new file mode 100644 index 000000000..83e8da461 --- /dev/null +++ b/tests/client-sdk/post_training/test_supervised_fine_tuning.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +POST_TRAINING_PROVIDER_TYPES = ["remote::nvidia"] + + +@pytest.fixture(scope="session") +def post_training_provider_available(llama_stack_client): + providers = llama_stack_client.providers.list() + post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES] + return len(post_training_providers) > 0 + + +def test_post_training_provider_registration(llama_stack_client, post_training_provider_available): + """Check if post_training is in the api list. + This is a sanity check to ensure the provider is registered.""" + if not post_training_provider_available: + pytest.skip("post training provider not available") + + providers = llama_stack_client.providers.list() + post_training_providers = [p for p in providers if p.provider_type in POST_TRAINING_PROVIDER_TYPES] + + assert len(post_training_providers) > 0 + + assert any("post_training" in provider.api for provider in post_training_providers) + + +def test_list_training_jobs(llama_stack_client, post_training_provider_available): + """Check if the list_jobs method returns a list of jobs.""" + if not post_training_provider_available: + pytest.skip("post training provider not available") + + jobs = llama_stack_client.post_training.job.list() + + assert jobs is not None + assert isinstance(jobs, list)