diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index f6bde8927..61249ad17 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -10,8 +10,7 @@ from uuid import uuid4
 import pytest
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
-from llama_stack_client.types.memory_insert_params import Document
+from llama_stack_client.types.agents.turn_create_params import Document
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 from llama_stack.apis.agents.agents import (
@@ -242,7 +241,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen
     codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
 
-    inflation_doc = AgentDocument(
+    inflation_doc = Document(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
         mime_type="text/csv",
     )
diff --git a/tests/integration/datasetio/test_datasetio.py b/tests/integration/datasetio/test_datasetio.py
index f112071a6..459589e7b 100644
--- a/tests/integration/datasetio/test_datasetio.py
+++ b/tests/integration/datasetio/test_datasetio.py
@@ -9,11 +9,25 @@ import mimetypes
 import os
 from pathlib import Path
 
+import pytest
+
 # How to run this test:
 #
 # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
 
 
+@pytest.fixture
+def dataset_for_test(llama_stack_client):
+    dataset_id = "test_dataset"
+    register_dataset(llama_stack_client, dataset_id=dataset_id)
+    yield
+    # Teardown - this always runs, even if the test fails
+    try:
+        llama_stack_client.datasets.unregister(dataset_id)
+    except Exception as e:
+        print(f"Warning: Failed to unregister test_dataset: {e}")
+
+
 def data_url_from_file(file_path: str) -> str:
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"File not found: {file_path}")
@@ -80,8 +94,7 @@ def test_register_unregister_dataset(llama_stack_client):
     assert len(response) == 0
 
 
-def test_get_rows_paginated(llama_stack_client):
-    register_dataset(llama_stack_client)
+def test_get_rows_paginated(llama_stack_client, dataset_for_test):
     response = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
         rows_in_page=3,
diff --git a/tests/integration/scoring/test_scoring.py b/tests/integration/scoring/test_scoring.py
index 2fcdf54e2..970a96f40 100644
--- a/tests/integration/scoring/test_scoring.py
+++ b/tests/integration/scoring/test_scoring.py
@@ -10,6 +10,19 @@ import pytest
 from ..datasetio.test_datasetio import register_dataset
 
 
+@pytest.fixture
+def rag_dataset_for_test(llama_stack_client):
+    dataset_id = "test_dataset"
+    register_dataset(llama_stack_client, for_rag=True, dataset_id=dataset_id)
+    yield  # This is where the test function will run
+
+    # Teardown - this always runs, even if the test fails
+    try:
+        llama_stack_client.datasets.unregister(dataset_id)
+    except Exception as e:
+        print(f"Warning: Failed to unregister test_dataset: {e}")
+
+
 @pytest.fixture
 def sample_judge_prompt_template():
     return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
@@ -79,9 +92,7 @@ def test_scoring_functions_register(
     # TODO: add unregister api for scoring functions
 
 
-def test_scoring_score(llama_stack_client):
-    register_dataset(llama_stack_client, for_rag=True)
-
+def test_scoring_score(llama_stack_client, rag_dataset_for_test):
     # scoring individual rows
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
         rows_in_page=3,
@@ -115,9 +126,9 @@ def test_scoring_score(llama_stack_client):
         assert len(response.results[x].score_rows) == 5
 
 
-def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
-    register_dataset(llama_stack_client, for_rag=True)
-
+def test_scoring_score_with_params_llm_as_judge(
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, rag_dataset_for_test
+):
     # scoring individual rows
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
         rows_in_page=3,
@@ -167,9 +178,8 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack
     ],
 )
 def test_scoring_score_with_aggregation_functions(
-    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id, rag_dataset_for_test
 ):
-    register_dataset(llama_stack_client, for_rag=True)
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
         rows_in_page=3,
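
Note on the pattern: the patch replaces per-test register_dataset calls with yield-based pytest fixtures, so the unregister teardown runs even when a test body raises. A minimal, self-contained sketch of that setup/teardown behavior, runnable with pytest on its own (the resource dict and test below are illustrative only, not part of this patch):

import pytest


@pytest.fixture
def resource():
    # Setup: runs before each test that requests this fixture.
    handle = {"registered": True}
    yield handle  # the test body executes at this point
    # Teardown: runs after the test, even if the test body raised,
    # so state never leaks into subsequent tests in the session.
    handle["registered"] = False


def test_uses_resource(resource):
    assert resource["registered"]

The try/except around unregister in the actual fixtures goes one step further: a teardown failure is downgraded to a warning instead of masking the test's own result.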