From 9da092ff2d7a0ec4d388334510e954c10a008f7c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 15 Mar 2025 16:17:22 -0700 Subject: [PATCH] test case --- tests/integration/datasets/test_datasets.py | 24 +++++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py index 0715be06e..be02705bc 100644 --- a/tests/integration/datasets/test_datasets.py +++ b/tests/integration/datasets/test_datasets.py @@ -16,16 +16,26 @@ import pytest # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets -def test_register_dataset(llama_stack_client): +@pytest.mark.parametrize( + "purpose, source, provider_id", + [ + ( + "eval/messages-answer", + { + "type": "uri", + "uri": "huggingface://datasets/llamastack/simpleqa?split=train", + }, + "huggingface", + ), + ], +) +def test_register_dataset(llama_stack_client, purpose, source, provider_id): dataset = llama_stack_client.datasets.register( - purpose="eval/messages-answer", - source={ - "type": "uri", - "uri": "huggingface://datasets/llamastack/simpleqa?split=train", - }, + purpose=purpose, + source=source, ) assert dataset.identifier is not None - assert dataset.provider_id == "huggingface" + assert dataset.provider_id == provider_id iterrow_response = llama_stack_client.datasets.iterrows( dataset.identifier, limit=10 )