diff --git a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py index 5d3cf5c03..83efe3991 100644 --- a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py +++ b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py @@ -12,7 +12,6 @@ from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.datasets import Dataset -from llama_stack.schema_utils import webmethod from .config import NvidiaDatasetIOConfig @@ -54,14 +53,14 @@ class NvidiaDatasetIOAdapter: """Register a new dataset. Args: - dataset_def : The dataset definition. - dataset_id: The ID of the dataset. - source: The source of the dataset. - metadata: The metadata of the dataset. - format: The format of the dataset. - description: The description of the dataset. + dataset_def [Dataset]: The dataset definition. + dataset_id [str]: The ID of the dataset. + source [DataSource]: The source of the dataset. + metadata [Dict[str, Any]]: The metadata of the dataset. + format [str]: The format of the dataset. + description [str]: The description of the dataset. Returns: - None + Dataset """ ## add warnings for unsupported params request_body = { @@ -78,8 +77,8 @@ class NvidiaDatasetIOAdapter: "/v1/datasets", json=request_body, ) + return dataset_def - @webmethod(route="/datasets/{dataset_id:path}", method="POST") async def update_dataset( self, dataset_id: str, @@ -91,7 +90,6 @@ class NvidiaDatasetIOAdapter: ) -> None: raise NotImplementedError("Not implemented") - @webmethod(route="/datasets/{dataset_id:path}", method="DELETE") async def unregister_dataset( self, dataset_id: str, diff --git a/tests/unit/providers/nvidia/test_datastore.py b/tests/unit/providers/nvidia/test_datastore.py new file mode 100644 index 000000000..eb4180b5e --- /dev/null +++ b/tests/unit/providers/nvidia/test_datastore.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import unittest +from unittest.mock import patch + +import pytest + +from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource +from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig +from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter + + +class TestNvidiaDatastore(unittest.TestCase): + def setUp(self): + os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" + + config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default" + ) + self.adapter = NvidiaDatasetIOAdapter(config) + self.make_request_patcher = patch( + "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" + ) + self.mock_make_request = self.make_request_patcher.start() + + def tearDown(self): + self.make_request_patcher.stop() + + @pytest.fixture(autouse=True) + def inject_fixtures(self, run_async): + self.run_async = run_async + + def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None): + """Helper method to verify request details in mock calls.""" + call_args = mock_call.call_args + + assert call_args[0][0] == expected_method + assert call_args[0][1] == expected_path + + if expected_json: + for key, value in expected_json.items(): + assert call_args[1]["json"][key] == value + + def test_register_dataset(self): + self.mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "default", + } + + dataset_def = Dataset( + identifier="test-dataset", + type="dataset", + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + metadata={"format": "jsonl", "description": "Test dataset description"}, + ) + + self.run_async(self.adapter.register_dataset(dataset_def)) + + self.mock_make_request.assert_called_once() + self._assert_request( + self.mock_make_request, + "POST", + "/v1/datasets", + expected_json={ + "name": "test-dataset", + "namespace": "default", + "files_url": "https://example.com/data.jsonl", + "project": "default", + "format": "jsonl", + "description": "Test dataset description", + }, + ) + + def test_unregister_dataset(self): + self.mock_make_request.return_value = { + "message": "Resource deleted successfully.", + "id": "dataset-81RSQp7FKX3rdBtKvF9Skn", + "deleted_at": None, + } + dataset_id = "test-dataset" + + self.run_async(self.adapter.unregister_dataset(dataset_id)) + + self.mock_make_request.assert_called_once() + self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") + + +if __name__ == "__main__": + unittest.main()