diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py index 0715be06e..be02705bc 100644 --- a/tests/integration/datasets/test_datasets.py +++ b/tests/integration/datasets/test_datasets.py @@ -16,16 +16,26 @@ import pytest # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets -def test_register_dataset(llama_stack_client): +@pytest.mark.parametrize( + "purpose, source, provider_id", + [ + ( + "eval/messages-answer", + { + "type": "uri", + "uri": "huggingface://datasets/llamastack/simpleqa?split=train", + }, + "huggingface", + ), + ], +) +def test_register_dataset(llama_stack_client, purpose, source, provider_id): dataset = llama_stack_client.datasets.register( - purpose="eval/messages-answer", - source={ - "type": "uri", - "uri": "huggingface://datasets/llamastack/simpleqa?split=train", - }, + purpose=purpose, + source=source, ) assert dataset.identifier is not None - assert dataset.provider_id == "huggingface" + assert dataset.provider_id == provider_id iterrow_response = llama_stack_client.datasets.iterrows( dataset.identifier, limit=10 )