From 6deef1ece09d4b4d60232a6c40737ad034b02c73 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 15 Jan 2025 12:55:19 -0800
Subject: [PATCH] rebase eval test w/ tool_runtime fixtures (#773)

# What does this PR do?

- fix eval tests to include tool_runtime fixtures
- rebase eval for extracting memory retrieval context

## Test Plan

```
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py
pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
```

- With notebook: https://gist.github.com/yanxi0830/1260a6cb7ec42498a195b88422462a34

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
---
 .../providers/inline/eval/meta_reference/eval.py | 11 +++++++++--
 llama_stack/providers/tests/eval/conftest.py     |  5 +++++
 llama_stack/providers/tests/eval/fixtures.py     | 11 ++++++++++-
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 408043db8..63c1e8d98 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -16,6 +16,9 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
+    MEMORY_QUERY_TOOL,
+)
 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
     get_valid_schemas,
@@ -146,8 +149,12 @@ class MetaReferenceEvalImpl(
             # check if there's a memory retrieval step and extract the context
             memory_rag_context = None
             for step in final_event.turn.steps:
-                if step.step_type == StepType.memory_retrieval.value:
-                    memory_rag_context = " ".join(x.text for x in step.inserted_context)
+                if step.step_type == StepType.tool_execution.value:
+                    for tool_response in step.tool_responses:
+                        if tool_response.tool_name == MEMORY_QUERY_TOOL:
+                            memory_rag_context = " ".join(
+                                x.text for x in tool_response.content
+                            )
 
             agent_generation = {}
             agent_generation[ColumnName.generated_answer.value] = (
diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py
index 1bb49d41f..3d6ef01b2 100644
--- a/llama_stack/providers/tests/eval/conftest.py
+++ b/llama_stack/providers/tests/eval/conftest.py
@@ -15,6 +15,7 @@ from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES
 from ..scoring.fixtures import SCORING_FIXTURES
+from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import EVAL_FIXTURES
 
 DEFAULT_PROVIDER_COMBINATIONS = [
@@ -27,6 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "agents": "meta_reference",
             "safety": "llama_guard",
             "memory": "faiss",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference_eval_fireworks_inference",
         marks=pytest.mark.meta_reference_eval_fireworks_inference,
@@ -40,6 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "agents": "meta_reference",
             "safety": "llama_guard",
             "memory": "faiss",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference_eval_together_inference",
         marks=pytest.mark.meta_reference_eval_together_inference,
@@ -53,6 +56,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "agents": "meta_reference",
             "safety": "llama_guard",
             "memory": "faiss",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference_eval_together_inference_huggingface_datasetio",
         marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
@@ -98,6 +102,7 @@ def pytest_generate_tests(metafunc):
             "agents": AGENTS_FIXTURES,
             "safety": SAFETY_FIXTURES,
             "memory": MEMORY_FIXTURES,
+            "tool_runtime": TOOL_RUNTIME_FIXTURES,
         }
         combinations = (
             get_provider_fixture_overrides(metafunc.config, available_fixtures)
diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py
index eba7c48a6..37bb0527a 100644
--- a/llama_stack/providers/tests/eval/fixtures.py
+++ b/llama_stack/providers/tests/eval/fixtures.py
@@ -35,7 +35,13 @@ EVAL_FIXTURES = ["meta_reference", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def eval_stack(request, inference_model, judge_model):
+async def eval_stack(
+    request,
+    inference_model,
+    judge_model,
+    tool_group_input_memory,
+    tool_group_input_tavily_search,
+):
     fixture_dict = request.param
 
     providers = {}
@@ -48,6 +54,7 @@ async def eval_stack(request, inference_model, judge_model):
         "agents",
         "safety",
         "memory",
+        "tool_runtime",
     ]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
@@ -63,6 +70,7 @@ async def eval_stack(request, inference_model, judge_model):
             Api.agents,
             Api.safety,
             Api.memory,
+            Api.tool_runtime,
         ],
         providers,
         provider_data,
@@ -73,6 +81,7 @@ async def eval_stack(request, inference_model, judge_model):
                 judge_model,
             ]
         ],
+        tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
    )
 
     return test_stack.impls