From a3c07ac10a6c08e5e87163e5b3d9b7b25cf66d2d Mon Sep 17 00:00:00 2001 From: raspawar Date: Wed, 9 Apr 2025 12:25:13 +0000 Subject: [PATCH] update tests --- tests/integration/datasets/test_datasets.py | 35 +------------- .../integration/providers/nvidia/__init__.py | 5 ++ .../providers/nvidia/test_datastore.py | 47 +++++++++++++++++++ tests/unit/providers/nvidia/test_datastore.py | 40 ++++++++++++++++ 4 files changed, 93 insertions(+), 34 deletions(-) create mode 100644 tests/integration/providers/nvidia/__init__.py create mode 100644 tests/integration/providers/nvidia/test_datastore.py diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py index a744a8581..18b31d39c 100644 --- a/tests/integration/datasets/test_datasets.py +++ b/tests/integration/datasets/test_datasets.py @@ -10,7 +10,7 @@ import mimetypes import os import pytest -from llama_stack.apis.datasets import Dataset + # How to run this test: # # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets @@ -94,36 +94,3 @@ def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, llama_stack_client.datasets.unregister(dataset.identifier) dataset_list = llama_stack_client.datasets.list() assert dataset.identifier not in [d.identifier for d in dataset_list] - -# nvidia provider only -@pytest.mark.parametrize( - "provider_id", - [ - "nvidia", - ], -) -def test_register_and_unregister(llama_stack_client, provider_id): - purpose = "eval/messages-answer" - source = { - "type": "uri", - "uri": "hf://datasets/llamastack/simpleqa?split=train", - } - dataset = llama_stack_client.datasets.register( - dataset_id=f"test-dataset-{provider_id}", - purpose=purpose, - source=source, - metadata={"provider": provider_id, "format": "json", "description": "Test dataset description"}, - ) - assert dataset.identifier is not None - assert dataset.provider_id == provider_id - assert dataset.identifier == f"test-dataset-{provider_id}" - - dataset_list = llama_stack_client.datasets.list() - provider_datasets = [d for d in dataset_list if d.provider_id == provider_id] - assert any(provider_datasets) - assert any([d.identifier == f"test-dataset-{provider_id}" for d in provider_datasets]) - - llama_stack_client.datasets.unregister(dataset.identifier) - dataset_list = llama_stack_client.datasets.list() - provider_datasets = [d for d in dataset_list if d.identifier == dataset.identifier] - assert not any(provider_datasets) \ No newline at end of file diff --git a/tests/integration/providers/nvidia/__init__.py b/tests/integration/providers/nvidia/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/integration/providers/nvidia/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/integration/providers/nvidia/test_datastore.py b/tests/integration/providers/nvidia/test_datastore.py new file mode 100644 index 000000000..f6ccd55a6 --- /dev/null +++ b/tests/integration/providers/nvidia/test_datastore.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + +# How to run this test: +# +# LLAMA_STACK_CONFIG="nvidia" pytest -v tests/integration/providers/nvidia/test_datastore.py + + +# nvidia provider only +@pytest.mark.parametrize( + "provider_id", + [ + "nvidia", + ], +) +def test_register_and_unregister(llama_stack_client, provider_id): + purpose = "eval/messages-answer" + source = { + "type": "uri", + "uri": "hf://datasets/llamastack/simpleqa?split=train", + } + dataset_id = f"test-dataset-{provider_id}" + dataset = llama_stack_client.datasets.register( + dataset_id=dataset_id, + purpose=purpose, + source=source, + metadata={"provider": provider_id, "format": "json", "description": "Test dataset description"}, + ) + assert dataset.identifier is not None + assert dataset.provider_id == provider_id + assert dataset.identifier == dataset_id + + dataset_list = llama_stack_client.datasets.list() + provider_datasets = [d for d in dataset_list if d.provider_id == provider_id] + assert any(provider_datasets) + assert any(d.identifier == dataset_id for d in provider_datasets) + + llama_stack_client.datasets.unregister(dataset.identifier) + dataset_list = llama_stack_client.datasets.list() + provider_datasets = [d for d in dataset_list if d.identifier == dataset.identifier] + assert not any(provider_datasets) diff --git a/tests/unit/providers/nvidia/test_datastore.py b/tests/unit/providers/nvidia/test_datastore.py index eb4180b5e..504f1af5d 100644 --- a/tests/unit/providers/nvidia/test_datastore.py +++ b/tests/unit/providers/nvidia/test_datastore.py @@ -93,6 +93,46 @@ class TestNvidiaDatastore(unittest.TestCase): self.mock_make_request.assert_called_once() self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") + def test_register_dataset_with_custom_namespace_project(self): + custom_config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], + dataset_namespace="custom-namespace", + project_id="custom-project", + ) + custom_adapter = NvidiaDatasetIOAdapter(custom_config) + + self.mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "custom-namespace", + } + + dataset_def = Dataset( + identifier="test-dataset", + type="dataset", + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + metadata={"format": "jsonl"}, + ) + + self.run_async(custom_adapter.register_dataset(dataset_def)) + + self.mock_make_request.assert_called_once() + self._assert_request( + self.mock_make_request, + "POST", + "/v1/datasets", + expected_json={ + "name": "test-dataset", + "namespace": "custom-namespace", + "files_url": "https://example.com/data.jsonl", + "project": "custom-project", + "format": "jsonl", + }, + ) + if __name__ == "__main__": unittest.main()