From 97d6b87e05b0cade3db7e3fa25b60deaf82476b4 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 16:47:22 -0700 Subject: [PATCH] datasetio --- tests/integration/datasetio/test_datasetio.py | 17 +++++++++-- tests/integration/scoring/test_scoring.py | 28 +++++++++---------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/tests/integration/datasetio/test_datasetio.py b/tests/integration/datasetio/test_datasetio.py index 5b1d1a37a..6dddf5915 100644 --- a/tests/integration/datasetio/test_datasetio.py +++ b/tests/integration/datasetio/test_datasetio.py @@ -9,10 +9,23 @@ import mimetypes import os from pathlib import Path +import pytest + # How to run this test: # # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio +@pytest.fixture +def test_dataset(llama_stack_client): + register_dataset(llama_stack_client) + yield # The test body runs here; the fixture passes no value to it + + # Teardown - runs even if the test fails; unregister errors are printed, not raised + try: + llama_stack_client.datasets.unregister("test_dataset") + except Exception as e: + print(f"Warning: Failed to unregister test_dataset: {e}") + def data_url_from_file(file_path: str) -> str: if not os.path.exists(file_path): @@ -80,8 +93,7 @@ def test_register_unregister_dataset(llama_stack_client): assert len(response) == 0 -def test_get_rows_paginated(llama_stack_client): - register_dataset(llama_stack_client) +def test_get_rows_paginated(llama_stack_client, test_dataset): response = llama_stack_client.datasetio.get_rows_paginated( dataset_id="test_dataset", rows_in_page=3, @@ -99,4 +111,3 @@ def test_get_rows_paginated(llama_stack_client): assert isinstance(response.rows, list) assert len(response.rows) == 2 assert response.next_page_token == "5" - llama_stack_client.datasets.unregister("test_dataset") diff --git a/tests/integration/scoring/test_scoring.py b/tests/integration/scoring/test_scoring.py index 6c2623705..3477516a2 100644 --- a/tests/integration/scoring/test_scoring.py +++ 
b/tests/integration/scoring/test_scoring.py @@ -9,6 +9,17 @@ import pytest from ..datasetio.test_datasetio import register_dataset +@pytest.fixture +def test_dataset_rag(llama_stack_client): + register_dataset(llama_stack_client, for_rag=True) + yield # The test body runs here; the fixture passes no value to it + + # Teardown - runs even if the test fails; unregister errors are printed, not raised + try: + llama_stack_client.datasets.unregister("test_dataset") + except Exception as e: + print(f"Warning: Failed to unregister test_dataset: {e}") + @pytest.fixture def sample_judge_prompt_template(): @@ -79,9 +90,7 @@ def test_scoring_functions_register( # TODO: add unregister api for scoring functions -def test_scoring_score(llama_stack_client): - register_dataset(llama_stack_client, for_rag=True) - +def test_scoring_score(llama_stack_client, test_dataset_rag): # scoring individual rows rows = llama_stack_client.datasetio.get_rows_paginated( dataset_id="test_dataset", @@ -114,12 +123,8 @@ def test_scoring_score(llama_stack_client): assert x in response.results assert len(response.results[x].score_rows) == 5 - llama_stack_client.datasets.unregister("test_dataset") - - -def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id): - register_dataset(llama_stack_client, for_rag=True) +def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id, test_dataset_rag): # scoring individual rows rows = llama_stack_client.datasetio.get_rows_paginated( dataset_id="test_dataset", @@ -159,8 +164,6 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge assert x in response.results assert len(response.results[x].score_rows) == 5 - llama_stack_client.datasets.unregister("test_dataset") - @pytest.mark.parametrize( "provider_id", @@ -171,9 +174,8 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge ], ) def test_scoring_score_with_aggregation_functions( - 
llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id + llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id, test_dataset_rag ): - register_dataset(llama_stack_client, for_rag=True) rows = llama_stack_client.datasetio.get_rows_paginated( dataset_id="test_dataset", rows_in_page=3, @@ -227,5 +229,3 @@ def test_scoring_score_with_aggregation_functions( assert x in response.results assert len(response.results[x].score_rows) == len(rows.rows) assert len(response.results[x].aggregated_results) == len(aggr_fns) - - llama_stack_client.datasets.unregister("test_dataset")