From 234f4e4583ef377b319fe3a5b38e18f99a190dec Mon Sep 17 00:00:00 2001
From: raspawar
Date: Wed, 9 Apr 2025 11:10:01 +0000
Subject: [PATCH] add integration test

---
 tests/integration/datasets/test_datasets.py | 44 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py
index 18b31d39c..a744a8581 100644
--- a/tests/integration/datasets/test_datasets.py
+++ b/tests/integration/datasets/test_datasets.py
@@ -10,7 +10,7 @@ import mimetypes
 import os
 
 import pytest
-
+from llama_stack.apis.datasets import Dataset
 # How to run this test:
 #
 # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
@@ -94,3 +94,45 @@ def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id,
     llama_stack_client.datasets.unregister(dataset.identifier)
     dataset_list = llama_stack_client.datasets.list()
     assert dataset.identifier not in [d.identifier for d in dataset_list]
+
+
+# nvidia provider only
+@pytest.mark.parametrize(
+    "provider_id",
+    [
+        "nvidia",
+    ],
+)
+def test_register_and_unregister(llama_stack_client, provider_id):
+    """Register a dataset, check that it is listed, then unregister it.
+
+    Skipped when the stack under test has no datasetio provider with this id.
+    """
+    datasetio_providers = {p.provider_id for p in llama_stack_client.providers.list() if p.api == "datasetio"}
+    if provider_id not in datasetio_providers:
+        pytest.skip(f"datasetio provider '{provider_id}' is not configured")
+
+    dataset_id = f"test-dataset-{provider_id}"
+    purpose = "eval/messages-answer"
+    source = {
+        "type": "uri",
+        "uri": "hf://datasets/llamastack/simpleqa?split=train",
+    }
+    dataset = llama_stack_client.datasets.register(
+        dataset_id=dataset_id,
+        purpose=purpose,
+        source=source,
+        metadata={"provider": provider_id, "format": "json", "description": "Test dataset description"},
+    )
+    try:
+        assert dataset.identifier == dataset_id
+        assert dataset.provider_id == provider_id
+
+        dataset_list = llama_stack_client.datasets.list()
+        assert dataset_id in [d.identifier for d in dataset_list if d.provider_id == provider_id]
+    finally:
+        # Always unregister so a failed assertion does not leave the dataset behind.
+        llama_stack_client.datasets.unregister(dataset.identifier)
+
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier not in [d.identifier for d in dataset_list]