From e6bbf8d20b87f46794bda659d8d15b0aa779f218 Mon Sep 17 00:00:00 2001
From: Rashmi Pawar <168514198+raspawar@users.noreply.github.com>
Date: Mon, 28 Apr 2025 22:11:59 +0530
Subject: [PATCH] feat: Add NVIDIA NeMo datastore (#1852)

# What does this PR do?
Implementation of the NeMo Datastore register and unregister APIs.

Open issues:
- `provider_id` gets set to `localfs` in `client.datasets.register()` because that is what `DatasetsRoutingTable` in `routing_tables.py` specifies (see #1860).

As a workaround, `"provider_id": "nvidia"` is currently passed in `metadata` and parsed in `DatasetsRoutingTable` (not the best approach, but a quick way to make it work for now).

## Test Plan
- Unit test cases: `pytest tests/unit/providers/nvidia/test_datastore.py`
```bash
========================================================== test session starts ===========================================================
platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-0.26.0, nbval-0.11.0, metadata-3.1.1, html-4.1.1, cov-6.1.0
asyncio: mode=strict, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 2 items

tests/unit/providers/nvidia/test_datastore.py ..                                                                                   [100%]

============================================================ warnings summary ============================================================
====================================================== 2 passed, 1 warning in 0.84s ======================================================
```

cc: @dglogo, @mattf, @yanxi0830
---
 .../self_hosted_distro/nvidia.md               |   2 +-
 .../distribution/routers/routing_tables.py     |   5 +-
 llama_stack/providers/registry/datasetio.py    |  11 ++
 .../remote/datasetio/nvidia/README.md          |  74 ++++++++++
 .../remote/datasetio/nvidia/__init__.py        |  23 +++
 .../remote/datasetio/nvidia/config.py          |  61 ++++++++
 .../remote/datasetio/nvidia/datasetio.py       | 112 ++++++++++++++
 llama_stack/templates/dependencies.json        |   1 +
 llama_stack/templates/nvidia/build.yaml        |   1 +
 llama_stack/templates/nvidia/nvidia.py         |   9 +-
 .../templates/nvidia/run-with-safety.yaml      |   7 +
 llama_stack/templates/nvidia/run.yaml          |  12 +-
 pyproject.toml                                 |   1 +
 .../integration/providers/nvidia/__init__.py   |   5 +
 .../integration/providers/nvidia/conftest.py   |  14 ++
 .../providers/nvidia/test_datastore.py         |  47 ++++++
 tests/unit/providers/nvidia/test_datastore.py  | 138 ++++++++++++++++++
 17 files changed, 514 insertions(+), 9 deletions(-)
 create mode 100644 llama_stack/providers/remote/datasetio/nvidia/README.md
 create mode 100644 llama_stack/providers/remote/datasetio/nvidia/__init__.py
 create mode 100644 llama_stack/providers/remote/datasetio/nvidia/config.py
 create mode 100644 llama_stack/providers/remote/datasetio/nvidia/datasetio.py
 create mode 100644 tests/integration/providers/nvidia/__init__.py
 create mode 100644 tests/integration/providers/nvidia/conftest.py
 create mode 100644 tests/integration/providers/nvidia/test_datastore.py
 create mode 100644 tests/unit/providers/nvidia/test_datastore.py

diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md
index 4407de779..a5bbbfdee 100644
--- a/docs/source/distributions/self_hosted_distro/nvidia.md
+++ b/docs/source/distributions/self_hosted_distro/nvidia.md
@@ -6,7 +6,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
-| datasetio | `inline::localfs` |
+| 
datasetio | `inline::localfs`, `remote::nvidia` | | eval | `remote::nvidia` | | inference | `remote::nvidia` | | post_training | `remote::nvidia` | diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 18b0c891f..68ee837bf 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -438,7 +438,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): provider_dataset_id = dataset_id # infer provider from source - if source.type == DatasetType.rows.value: + if metadata: + if metadata.get("provider_id"): + provider_id = metadata.get("provider_id") # pass through from nvidia datasetio + elif source.type == DatasetType.rows.value: provider_id = "localfs" elif source.type == DatasetType.uri.value: # infer provider from uri diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index f83dcbc60..7db136136 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -36,4 +36,15 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", ), ), + remote_provider_spec( + api=Api.datasetio, + adapter=AdapterSpec( + adapter_type="nvidia", + pip_packages=[ + "datasets", + ], + module="llama_stack.providers.remote.datasetio.nvidia", + config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", + ), + ), ] diff --git a/llama_stack/providers/remote/datasetio/nvidia/README.md b/llama_stack/providers/remote/datasetio/nvidia/README.md new file mode 100644 index 000000000..1d3d15132 --- /dev/null +++ b/llama_stack/providers/remote/datasetio/nvidia/README.md @@ -0,0 +1,74 @@ +# NVIDIA DatasetIO Provider for LlamaStack + +This provider enables dataset management using NVIDIA's NeMo Customizer service. 
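+
+Under the hood, registering a dataset through this provider issues a `POST /v1/datasets` request to the NeMo Dataset API (at `NVIDIA_DATASETS_URL`). The sketch below shows the shape of the request body the adapter builds; the field values are illustrative only:
+
+```python
+# Illustrative sketch of the payload built by NvidiaDatasetIOAdapter.register_dataset.
+request_body = {
+    "name": "my-training-dataset",                        # dataset identifier
+    "namespace": "default",                               # NVIDIA_DATASET_NAMESPACE
+    "files_url": "hf://datasets/default/sample-dataset",  # the dataset source URI
+    "project": "test-project",                            # NVIDIA_PROJECT_ID
+    # optional fields, copied from the dataset's metadata when present
+    "format": "json",
+    "description": "Dataset for LLM fine-tuning",
+}
+```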
+
+## Features
+
+- Register datasets for fine-tuning LLMs
+- Unregister datasets
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with NVIDIA configuration
+- Access to a hosted NVIDIA NeMo Microservice
+- API key for authentication with the NVIDIA service
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+llama stack build --template nvidia --image-type conda
+```
+
+### Basic Usage with the LlamaStack Python Client
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = "your-api-key"
+os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
+os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
+os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
+os.environ["NVIDIA_PROJECT_ID"] = "test-project"
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Register a dataset
+
+```python
+client.datasets.register(
+    purpose="post-training/messages",
+    dataset_id="my-training-dataset",
+    source={"type": "uri", "uri": "hf://datasets/default/sample-dataset"},
+    metadata={
+        "format": "json",
+        "description": "Dataset for LLM fine-tuning",
+        "provider_id": "nvidia",
+    },
+)
+```
+
+#### Get a list of all registered datasets
+
+```python
+datasets = client.datasets.list()
+for dataset in datasets:
+    print(f"Dataset ID: {dataset.identifier}")
+    print(f"Description: {dataset.metadata.get('description', '')}")
+    print(f"Source: {dataset.source.uri}")
+    print("---")
+```
+
+#### Unregister a dataset
+
+```python
+client.datasets.unregister(dataset_id="my-training-dataset")
+```
diff --git a/llama_stack/providers/remote/datasetio/nvidia/__init__.py b/llama_stack/providers/remote/datasetio/nvidia/__init__.py
new file mode 100644
index 000000000..418daec8d
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/nvidia/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import NvidiaDatasetIOConfig
+
+
+async def get_adapter_impl(
+    config: NvidiaDatasetIOConfig,
+    _deps,
+):
+    from .datasetio import NvidiaDatasetIOAdapter
+
+    if not isinstance(config, NvidiaDatasetIOConfig):
+        raise RuntimeError(f"Unexpected config type: {type(config)}")
+
+    impl = NvidiaDatasetIOAdapter(config)
+    return impl
+
+
+__all__ = ["get_adapter_impl", "NvidiaDatasetIOAdapter"]
diff --git a/llama_stack/providers/remote/datasetio/nvidia/config.py b/llama_stack/providers/remote/datasetio/nvidia/config.py
new file mode 100644
index 000000000..7f3dbdfbd
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/nvidia/config.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import warnings
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field
+
+
+class NvidiaDatasetIOConfig(BaseModel):
+    """Configuration for NVIDIA DatasetIO implementation."""
+
+    api_key: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
+        description="The NVIDIA API key.",
+    )
+
+    dataset_namespace: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
+        description="The NVIDIA dataset namespace.",
+    )
+
+    project_id: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
+        description="The NVIDIA project ID.",
+    )
+
+    datasets_url: str = Field(
+        default_factory=lambda: os.getenv("NVIDIA_DATASETS_URL", "http://nemo.test"),
+        description="Base URL for the NeMo Dataset API.",
+    )
+
+    # Warn when default values are used because the corresponding environment
+    # variables are unset (runs after pydantic model initialization).
+    def model_post_init(self, __context: Any) -> None:
+        default_values = []
+        if os.getenv("NVIDIA_PROJECT_ID") is None:
+            default_values.append("project_id='test-project'")
+        if os.getenv("NVIDIA_DATASET_NAMESPACE") is None:
+            default_values.append("dataset_namespace='default'")
+        if os.getenv("NVIDIA_DATASETS_URL") is None:
+            default_values.append("datasets_url='http://nemo.test'")
+
+        if default_values:
+            warnings.warn(
+                f"Using default values: {', '.join(default_values)}. "
+                "Please set the environment variables to avoid this default behavior.",
+                stacklevel=2,
+            )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "api_key": "${env.NVIDIA_API_KEY:}",
+            "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
+            "project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
+            "datasets_url": "${env.NVIDIA_DATASETS_URL:http://nemo.test}",
+        }
diff --git a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py
new file mode 100644
index 000000000..83efe3991
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py
@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
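+
+# This adapter maps Llama Stack dataset operations onto the NeMo Dataset API:
+#   register_dataset   -> POST   {datasets_url}/v1/datasets
+#   unregister_dataset -> DELETE {datasets_url}/v1/datasets/{namespace}/{dataset_id}
+# update_dataset, iterrows, and append_rows are not implemented yet.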
+ +from typing import Any, Dict, List, Optional + +import aiohttp + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.datasets import Dataset + +from .config import NvidiaDatasetIOConfig + + +class NvidiaDatasetIOAdapter: + """Nvidia NeMo DatasetIO API.""" + + def __init__(self, config: NvidiaDatasetIOConfig): + self.config = config + self.headers = {} + + async def _make_request( + self, + method: str, + path: str, + headers: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Dict[str, Any]: + """Helper method to make HTTP requests to the Customizer API.""" + url = f"{self.config.datasets_url}{path}" + request_headers = self.headers.copy() + + if headers: + request_headers.update(headers) + + async with aiohttp.ClientSession(headers=request_headers) as session: + async with session.request(method, url, params=params, json=json, **kwargs) as response: + if response.status != 200: + error_data = await response.json() + raise Exception(f"API request failed: {error_data}") + return await response.json() + + async def register_dataset( + self, + dataset_def: Dataset, + ) -> Dataset: + """Register a new dataset. + + Args: + dataset_def [Dataset]: The dataset definition. + dataset_id [str]: The ID of the dataset. + source [DataSource]: The source of the dataset. + metadata [Dict[str, Any]]: The metadata of the dataset. + format [str]: The format of the dataset. + description [str]: The description of the dataset. + Returns: + Dataset + """ + ## add warnings for unsupported params + request_body = { + "name": dataset_def.identifier, + "namespace": self.config.dataset_namespace, + "files_url": dataset_def.source.uri, + "project": self.config.project_id, + } + if dataset_def.metadata: + request_body["format"] = dataset_def.metadata.get("format") + request_body["description"] = dataset_def.metadata.get("description") + await self._make_request( + "POST", + "/v1/datasets", + json=request_body, + ) + return dataset_def + + async def update_dataset( + self, + dataset_id: str, + dataset_schema: Dict[str, ParamType], + url: URL, + provider_dataset_id: Optional[str] = None, + provider_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + raise NotImplementedError("Not implemented") + + async def unregister_dataset( + self, + dataset_id: str, + ) -> None: + await self._make_request( + "DELETE", + f"/v1/datasets/{self.config.dataset_namespace}/{dataset_id}", + headers={"Accept": "application/json", "Content-Type": "application/json"}, + ) + + async def iterrows( + self, + dataset_id: str, + start_index: Optional[int] = None, + limit: Optional[int] = None, + ) -> PaginatedResponse: + raise NotImplementedError("Not implemented") + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + raise NotImplementedError("Not implemented") diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index 4c16411f0..1f25dda14 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -394,6 +394,7 @@ "aiosqlite", "blobfile", "chardet", + "datasets", "faiss-cpu", "fastapi", "fire", diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index a33fa3737..a05cf97ad 100644 --- 
a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -18,6 +18,7 @@ distribution_spec: - remote::nvidia datasetio: - inline::localfs + - remote::nvidia scoring: - inline::basic tool_runtime: diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index 463c13879..bfd004037 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -7,6 +7,7 @@ from pathlib import Path from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES @@ -23,7 +24,7 @@ def get_distribution_template() -> DistributionTemplate: "telemetry": ["inline::meta-reference"], "eval": ["remote::nvidia"], "post_training": ["remote::nvidia"], - "datasetio": ["inline::localfs"], + "datasetio": ["inline::localfs", "remote::nvidia"], "scoring": ["inline::basic"], "tool_runtime": ["inline::rag-runtime"], } @@ -38,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate: provider_type="remote::nvidia", config=NVIDIASafetyConfig.sample_run_config(), ) + datasetio_provider = Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + config=NvidiaDatasetIOConfig.sample_run_config(), + ) eval_provider = Provider( provider_id="nvidia", provider_type="remote::nvidia", @@ -75,6 +81,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider], + "datasetio": [datasetio_provider], "eval": [eval_provider], }, default_models=default_models, diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index a3e5fefa4..5f594604b 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -74,6 +74,13 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db + - provider_id: nvidia + provider_type: remote::nvidia + config: + api_key: ${env.NVIDIA_API_KEY:} + dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default} + project_id: ${env.NVIDIA_PROJECT_ID:test-project} + datasets_url: ${env.NVIDIA_DATASETS_URL:http://nemo.test} scoring: - provider_id: basic provider_type: inline::basic diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 271ce1a16..b068e7f1e 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -62,13 +62,13 @@ providers: project_id: ${env.NVIDIA_PROJECT_ID:test-project} customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test} datasetio: - - provider_id: localfs - provider_type: inline::localfs + - provider_id: nvidia + provider_type: remote::nvidia config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db + api_key: ${env.NVIDIA_API_KEY:} + dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default} + project_id: ${env.NVIDIA_PROJECT_ID:test-project} + datasets_url: ${env.NVIDIA_DATASETS_URL:http://nemo.test} scoring: - provider_id: basic provider_type: inline::basic diff --git a/pyproject.toml b/pyproject.toml index 
3424cf384..0f44ca053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -260,6 +260,7 @@ exclude = [ "^llama_stack/providers/inline/scoring/llm_as_judge/", "^llama_stack/providers/remote/agents/sample/", "^llama_stack/providers/remote/datasetio/huggingface/", + "^llama_stack/providers/remote/datasetio/nvidia/", "^llama_stack/providers/remote/inference/anthropic/", "^llama_stack/providers/remote/inference/bedrock/", "^llama_stack/providers/remote/inference/cerebras/", diff --git a/tests/integration/providers/nvidia/__init__.py b/tests/integration/providers/nvidia/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/integration/providers/nvidia/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/integration/providers/nvidia/conftest.py b/tests/integration/providers/nvidia/conftest.py new file mode 100644 index 000000000..8beb113b0 --- /dev/null +++ b/tests/integration/providers/nvidia/conftest.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os + +import pytest + +# Skip all tests in this directory when running in GitHub Actions +in_github_actions = os.environ.get("GITHUB_ACTIONS") == "true" +if in_github_actions: + pytest.skip("Skipping NVIDIA tests in GitHub Actions environment", allow_module_level=True) diff --git a/tests/integration/providers/nvidia/test_datastore.py b/tests/integration/providers/nvidia/test_datastore.py new file mode 100644 index 000000000..5f96dee9f --- /dev/null +++ b/tests/integration/providers/nvidia/test_datastore.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +import pytest + +# How to run this test: +# +# LLAMA_STACK_CONFIG="nvidia" pytest -v tests/integration/providers/nvidia/test_datastore.py + + +# nvidia provider only +@pytest.mark.parametrize( + "provider_id", + [ + "nvidia", + ], +) +def test_register_and_unregister(llama_stack_client, provider_id): + purpose = "eval/messages-answer" + source = { + "type": "uri", + "uri": "hf://datasets/llamastack/simpleqa?split=train", + } + dataset_id = f"test-dataset-{provider_id}" + dataset = llama_stack_client.datasets.register( + dataset_id=dataset_id, + purpose=purpose, + source=source, + metadata={"provider_id": provider_id, "format": "json", "description": "Test dataset description"}, + ) + assert dataset.identifier is not None + assert dataset.provider_id == provider_id + assert dataset.identifier == dataset_id + + dataset_list = llama_stack_client.datasets.list() + provider_datasets = [d for d in dataset_list if d.provider_id == provider_id] + assert any(provider_datasets) + assert any(d.identifier == dataset_id for d in provider_datasets) + + llama_stack_client.datasets.unregister(dataset.identifier) + dataset_list = llama_stack_client.datasets.list() + provider_datasets = [d for d in dataset_list if d.identifier == dataset.identifier] + assert not any(provider_datasets) diff --git a/tests/unit/providers/nvidia/test_datastore.py b/tests/unit/providers/nvidia/test_datastore.py new file mode 100644 index 000000000..a17e51a9c --- /dev/null +++ b/tests/unit/providers/nvidia/test_datastore.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import unittest +from unittest.mock import patch + +import pytest + +from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource +from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig +from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter + + +class TestNvidiaDatastore(unittest.TestCase): + def setUp(self): + os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" + + config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default" + ) + self.adapter = NvidiaDatasetIOAdapter(config) + self.make_request_patcher = patch( + "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" + ) + self.mock_make_request = self.make_request_patcher.start() + + def tearDown(self): + self.make_request_patcher.stop() + + @pytest.fixture(autouse=True) + def inject_fixtures(self, run_async): + self.run_async = run_async + + def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None): + """Helper method to verify request details in mock calls.""" + call_args = mock_call.call_args + + assert call_args[0][0] == expected_method + assert call_args[0][1] == expected_path + + if expected_json: + for key, value in expected_json.items(): + assert call_args[1]["json"][key] == value + + def test_register_dataset(self): + self.mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "default", + } + + dataset_def = Dataset( + identifier="test-dataset", + type="dataset", + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + 
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"}, + ) + + self.run_async(self.adapter.register_dataset(dataset_def)) + + self.mock_make_request.assert_called_once() + self._assert_request( + self.mock_make_request, + "POST", + "/v1/datasets", + expected_json={ + "name": "test-dataset", + "namespace": "default", + "files_url": "https://example.com/data.jsonl", + "project": "default", + "format": "jsonl", + "description": "Test dataset description", + }, + ) + + def test_unregister_dataset(self): + self.mock_make_request.return_value = { + "message": "Resource deleted successfully.", + "id": "dataset-81RSQp7FKX3rdBtKvF9Skn", + "deleted_at": None, + } + dataset_id = "test-dataset" + + self.run_async(self.adapter.unregister_dataset(dataset_id)) + + self.mock_make_request.assert_called_once() + self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") + + def test_register_dataset_with_custom_namespace_project(self): + custom_config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], + dataset_namespace="custom-namespace", + project_id="custom-project", + ) + custom_adapter = NvidiaDatasetIOAdapter(custom_config) + + self.mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "custom-namespace", + } + + dataset_def = Dataset( + identifier="test-dataset", + type="dataset", + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + metadata={"format": "jsonl"}, + ) + + self.run_async(custom_adapter.register_dataset(dataset_def)) + + self.mock_make_request.assert_called_once() + self._assert_request( + self.mock_make_request, + "POST", + "/v1/datasets", + expected_json={ + "name": "test-dataset", + "namespace": "custom-namespace", + "files_url": "https://example.com/data.jsonl", + "project": "custom-project", + "format": "jsonl", + }, + ) + + +if __name__ == "__main__": + unittest.main()