diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 408043db8..63c1e8d98 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -16,6 +16,9 @@
 from llama_stack.apis.scoring import Scoring
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
+    MEMORY_QUERY_TOOL,
+)
 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
     get_valid_schemas,
@@ -146,8 +149,12 @@ class MetaReferenceEvalImpl(
             # check if there's a memory retrieval step and extract the context
             memory_rag_context = None
             for step in final_event.turn.steps:
-                if step.step_type == StepType.memory_retrieval.value:
-                    memory_rag_context = " ".join(x.text for x in step.inserted_context)
+                if step.step_type == StepType.tool_execution.value:
+                    for tool_response in step.tool_responses:
+                        if tool_response.tool_name == MEMORY_QUERY_TOOL:
+                            memory_rag_context = " ".join(
+                                x.text for x in tool_response.content
+                            )
 
             agent_generation = {}
             agent_generation[ColumnName.generated_answer.value] = (
diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py
index 1bb49d41f..3d6ef01b2 100644
--- a/llama_stack/providers/tests/eval/conftest.py
+++ b/llama_stack/providers/tests/eval/conftest.py
@@ -15,6 +15,7 @@ from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES
 from ..scoring.fixtures import SCORING_FIXTURES
+from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import EVAL_FIXTURES
 
 DEFAULT_PROVIDER_COMBINATIONS = [
@@ -27,6 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "agents": "meta_reference",
             "safety": "llama_guard",
             "memory": "faiss",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference_eval_fireworks_inference",
         marks=pytest.mark.meta_reference_eval_fireworks_inference,
@@ -40,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "agents": "meta_reference",
             "safety": "llama_guard",
             "memory": "faiss",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference_eval_together_inference",
         marks=pytest.mark.meta_reference_eval_together_inference,
@@ -53,6 +56,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "agents": "meta_reference",
             "safety": "llama_guard",
             "memory": "faiss",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference_eval_together_inference_huggingface_datasetio",
         marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
@@ -98,6 +102,7 @@ def pytest_generate_tests(metafunc):
         "agents": AGENTS_FIXTURES,
         "safety": SAFETY_FIXTURES,
         "memory": MEMORY_FIXTURES,
+        "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
     combinations = (
         get_provider_fixture_overrides(metafunc.config, available_fixtures)
diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py
index eba7c48a6..37bb0527a 100644
--- a/llama_stack/providers/tests/eval/fixtures.py
+++ b/llama_stack/providers/tests/eval/fixtures.py
@@ -35,7 +35,13 @@ EVAL_FIXTURES = ["meta_reference", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def eval_stack(request, inference_model, judge_model):
+async def eval_stack(
+    request,
+    inference_model,
+    judge_model,
+    tool_group_input_memory,
+    tool_group_input_tavily_search,
+):
     fixture_dict = request.param
 
     providers = {}
@@ -48,6 +54,7 @@ async def eval_stack(request, inference_model, judge_model):
         "agents",
         "safety",
         "memory",
+        "tool_runtime",
     ]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
@@ -63,6 +70,7 @@
             Api.agents,
             Api.safety,
             Api.memory,
+            Api.tool_runtime,
         ],
         providers,
         provider_data,
@@ -73,6 +81,7 @@ async def eval_stack(request, inference_model, judge_model):
                 judge_model,
             ]
         ],
+        tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
    )
 
     return test_stack.impls