diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index e1af29e12..d257cb8e7 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from tqdm import tqdm -from llama_stack.apis.agents import Agents +from llama_stack.apis.agents import Agents, StepType from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval_tasks import EvalTask @@ -135,11 +135,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM ) ] final_event = turn_response[-1].event.payload - generations.append( - { - ColumnName.generated_answer.value: final_event.turn.output_message.content - } + + # check if there's a memory retrieval step and extract the context + memory_rag_context = None + for step in final_event.turn.steps: + if step.step_type == StepType.memory_retrieval.value: + memory_rag_context = " ".join(x.text for x in step.inserted_context) + + agent_generation = {} + agent_generation[ColumnName.generated_answer.value] = ( + final_event.turn.output_message.content ) + if memory_rag_context: + agent_generation[ColumnName.context.value] = memory_rag_context + + generations.append(agent_generation) return generations