From b7e59ba002d1bbf4a526ce47435b68fc1df5bbaf Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 15 Jan 2025 11:45:59 -0800 Subject: [PATCH] fix eval test w/ tools --- llama_stack/providers/tests/eval/conftest.py | 5 +++++ llama_stack/providers/tests/eval/fixtures.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py index 1bb49d41f..3d6ef01b2 100644 --- a/llama_stack/providers/tests/eval/conftest.py +++ b/llama_stack/providers/tests/eval/conftest.py @@ -15,6 +15,7 @@ from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES from ..scoring.fixtures import SCORING_FIXTURES +from ..tools.fixtures import TOOL_RUNTIME_FIXTURES from .fixtures import EVAL_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ @@ -27,6 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "agents": "meta_reference", "safety": "llama_guard", "memory": "faiss", + "tool_runtime": "memory_and_search", }, id="meta_reference_eval_fireworks_inference", marks=pytest.mark.meta_reference_eval_fireworks_inference, @@ -40,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "agents": "meta_reference", "safety": "llama_guard", "memory": "faiss", + "tool_runtime": "memory_and_search", }, id="meta_reference_eval_together_inference", marks=pytest.mark.meta_reference_eval_together_inference, @@ -53,6 +56,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "agents": "meta_reference", "safety": "llama_guard", "memory": "faiss", + "tool_runtime": "memory_and_search", }, id="meta_reference_eval_together_inference_huggingface_datasetio", marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio, @@ -98,6 +102,7 @@ def pytest_generate_tests(metafunc): "agents": AGENTS_FIXTURES, "safety": SAFETY_FIXTURES, "memory": MEMORY_FIXTURES, + "tool_runtime": TOOL_RUNTIME_FIXTURES, } combinations = ( get_provider_fixture_overrides(metafunc.config, available_fixtures) diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py index eba7c48a6..37bb0527a 100644 --- a/llama_stack/providers/tests/eval/fixtures.py +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -35,7 +35,13 @@ EVAL_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") -async def eval_stack(request, inference_model, judge_model): +async def eval_stack( + request, + inference_model, + judge_model, + tool_group_input_memory, + tool_group_input_tavily_search, +): fixture_dict = request.param providers = {} @@ -48,6 +54,7 @@ async def eval_stack(request, inference_model, judge_model): "agents", "safety", "memory", + "tool_runtime", ]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = fixture.providers @@ -63,6 +70,7 @@ async def eval_stack(request, inference_model, judge_model): Api.agents, Api.safety, Api.memory, + Api.tool_runtime, ], providers, provider_data, @@ -73,6 +81,7 @@ async def eval_stack(request, inference_model, judge_model): judge_model, ] ], + tool_groups=[tool_group_input_memory, tool_group_input_tavily_search], ) return test_stack.impls