From 663c6b05379bded20a07379d61cc00490e0c5542 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 15:06:47 -0800 Subject: [PATCH] fix: duplicate ToolResponseMessage in Turn message history (#1305) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Reproduce with: https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py - **Root cause**: when we have ToolResponseMessage as part of Turn, we will create duplicate ToolResponseMessage in the conversation history when getting messages from a Turn. - Fix: avoid adding duplicate ToolResponseMessage from a turn's input_messages. - If it is part of a Turn's steps, only add it when processing the steps. - If it is not part of a Turn's steps, add it. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` ``` python -m examples.agents.e2e_loop_with_client_tools localhost 8321 ``` ```python Turn( │ input_messages=[ │ │ UserMessage( │ │ │ content='What was the closing price of Google stock (ticker symbol GOOG) for 2023 ?', │ │ │ role='user', │ │ │ context=None │ │ ), │ │ ToolResponseMessage( │ │ │ call_id='0d5f94fb-f070-4dc1-8eeb-63eb5918ec94', │ │ │ content='"[{\\"(\'Year\', \'\')\\":2023,\\"(\'Close\', \'GOOG\')\\":140.4254302979}]"', │ │ │ role='tool', │ │ │ tool_name='get_ticker_data' │ │ ) │ ], │ output_message=CompletionMessage( │ │ content='Note: The actual closing price for 2023 may not be available or may be different from the result obtained above. The result is based on a hypothetical call to the get_ticker_data function.', │ │ role='assistant', │ │ stop_reason='end_of_turn', │ │ tool_calls=[] │ ), │ session_id='4c791107-f0d8-456e-a27f-aa2fdc72b871', │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 25, 412928, tzinfo=TzInfo(-08:00)), │ steps=[ │ │ ShieldCallStep( │ │ │ step_id='e0514587-b7d6-4bba-8609-8e05a3a46d8a', │ │ │ step_type='shield_call', │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 25, 858382, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 25, 425204, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ), │ │ InferenceStep( │ │ │ api_model_response=CompletionMessage( │ │ │ │ content='', │ │ │ │ role='assistant', │ │ │ │ stop_reason='end_of_turn', │ │ │ │ tool_calls=[ │ │ │ │ │ ToolCall( │ │ │ │ │ │ arguments={ │ │ │ │ │ │ │ 'ticker_symbol': 'GOOG', │ │ │ │ │ │ │ 'start': '2023-01-01', │ │ │ │ │ │ │ 'end': '2023-12-31' │ │ │ │ │ │ }, │ │ │ │ │ │ call_id='0d5f94fb-f070-4dc1-8eeb-63eb5918ec94', │ │ │ │ │ │ tool_name='get_ticker_data' │ │ │ │ │ ) │ │ │ │ ] │ │ │ ), │ │ │ step_id='a3ceec6a-f149-49d5-a1c2-db461e3f6e9f', │ │ │ step_type='inference', │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 26, 910179, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 25, 871130, tzinfo=TzInfo(-08:00)) │ │ ), │ │ ShieldCallStep( │ │ │ step_id='f9339865-96ca-4425-af42-a87bab343e24', │ │ │ step_type='shield_call', │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 28, 383013, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 26, 944012, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ), │ │ ToolExecutionStep( │ │ │ step_id='e317b74a-c4f3-4845-99a3-7d93aa6ea6c8', │ │ │ step_type='tool_execution', │ │ │ tool_calls=[ │ │ │ │ ToolCall( │ │ │ │ │ arguments={'ticker_symbol': 'GOOG', 'start': '2023-01-01', 'end': '2023-12-31'}, │ │ │ │ │ call_id='0d5f94fb-f070-4dc1-8eeb-63eb5918ec94', │ │ │ │ │ tool_name='get_ticker_data' │ │ │ │ ) │ │ │ ], │ │ │ tool_responses=[ │ │ │ │ ToolResponse( │ │ │ │ │ call_id='0d5f94fb-f070-4dc1-8eeb-63eb5918ec94', │ │ │ │ │ content='"[{\\"(\'Year\', \'\')\\":2023,\\"(\'Close\', \'GOOG\')\\":140.4254302979}]"', │ │ │ │ │ tool_name='get_ticker_data', │ │ │ │ │ metadata=None │ │ │ │ ) │ │ │ ], │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 28, 718810, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 26, 943792, tzinfo=TzInfo(-08:00)) │ │ ), │ │ ShieldCallStep( │ │ │ step_id='c4236616-db89-4c04-ad04-f51cfb726385', │ │ │ step_type='shield_call', │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 28, 958946, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 28, 732680, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ), │ │ InferenceStep( │ │ │ api_model_response=CompletionMessage( │ │ │ │ content='Note: The actual closing price for 2023 may not be available or may be different from the result obtained above. The result is based on a hypothetical call to the get_ticker_data function.', │ │ │ │ role='assistant', │ │ │ │ stop_reason='end_of_turn', │ │ │ │ tool_calls=[] │ │ │ ), │ │ │ step_id='3386f896-2026-41e4-a60f-f6f3c3981cf6', │ │ │ step_type='inference', │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 37, 74750, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 28, 970724, tzinfo=TzInfo(-08:00)) │ │ ), │ │ ShieldCallStep( │ │ │ step_id='bc57ac8c-f94e-4758-bf1a-0dd734eca1cf', │ │ │ step_type='shield_call', │ │ │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 37, 443016, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 13, 59, 37, 86726, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ) │ ], │ turn_id='6ed9c25a-a4fe-4b51-ae13-de248624c2fc', │ completed_at=datetime.datetime(2025, 2, 27, 13, 59, 37, 459456, tzinfo=TzInfo(-08:00)), │ output_attachments=[] ) ``` ```python Turn( │ input_messages=[ │ │ UserMessage(content='What is 40+30?', role='user', context=None), │ │ ToolResponseMessage( │ │ │ call_id='8e54aca9-244d-44ca-ada0-0365090e8622', │ │ │ content='{"success": true, "result": 70.0}', │ │ │ role='tool', │ │ │ tool_name='calculator' │ │ ) │ ], │ output_message=CompletionMessage( │ │ content='The result of the calculation is 70.', │ │ role='assistant', │ │ stop_reason='end_of_turn', │ │ tool_calls=[] │ ), │ session_id='4c791107-f0d8-456e-a27f-aa2fdc72b871', │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 0, 156903, tzinfo=TzInfo(-08:00)), │ steps=[ │ │ ShieldCallStep( │ │ │ step_id='17b6b645-31cc-4be9-a758-a4f3b741ced9', │ │ │ step_type='shield_call', │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 0, 780564, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 0, 174515, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ), │ │ InferenceStep( │ │ │ api_model_response=CompletionMessage( │ │ │ │ content='', │ │ │ │ role='assistant', │ │ │ │ stop_reason='end_of_turn', │ │ │ │ tool_calls=[ │ │ │ │ │ ToolCall( │ │ │ │ │ │ arguments={'x': 40.0, 'y': 30.0, 'operation': 'add'}, │ │ │ │ │ │ call_id='8e54aca9-244d-44ca-ada0-0365090e8622', │ │ │ │ │ │ tool_name='calculator' │ │ │ │ │ ) │ │ │ │ ] │ │ │ ), │ │ │ step_id='f59e951a-2b75-497d-a075-ec9aad9aad12', │ │ │ step_type='inference', │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 2, 141869, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 0, 792047, tzinfo=TzInfo(-08:00)) │ │ ), │ │ ShieldCallStep( │ │ │ step_id='efafa0cf-23b9-4a90-8350-3a186d80925d', │ │ │ step_type='shield_call', │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 2, 766293, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 2, 177473, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ), │ │ ToolExecutionStep( │ │ │ step_id='877cfbe7-57a8-4056-9c29-49aa38dd337c', │ │ │ step_type='tool_execution', │ │ │ tool_calls=[ │ │ │ │ ToolCall( │ │ │ │ │ arguments={'x': 40.0, 'y': 30.0, 'operation': 'add'}, │ │ │ │ │ call_id='8e54aca9-244d-44ca-ada0-0365090e8622', │ │ │ │ │ tool_name='calculator' │ │ │ │ ) │ │ │ ], │ │ │ tool_responses=[ │ │ │ │ ToolResponse( │ │ │ │ │ call_id='8e54aca9-244d-44ca-ada0-0365090e8622', │ │ │ │ │ content='{"success": true, "result": 70.0}', │ │ │ │ │ tool_name='calculator', │ │ │ │ │ metadata=None │ │ │ │ ) │ │ │ ], │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 2, 930899, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 2, 177202, tzinfo=TzInfo(-08:00)) │ │ ), │ │ ShieldCallStep( │ │ │ step_id='d47c6160-45d9-47c1-8e39-2faae65ee468', │ │ │ step_type='shield_call', │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 3, 510402, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 2, 949433, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ), │ │ InferenceStep( │ │ │ api_model_response=CompletionMessage( │ │ │ │ content='The result of the calculation is 70.', │ │ │ │ role='assistant', │ │ │ │ stop_reason='end_of_turn', │ │ │ │ tool_calls=[] │ │ │ ), │ │ │ step_id='660ba1cc-770e-471c-bf6e-11e103d74443', │ │ │ step_type='inference', │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 4, 814944, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 3, 521309, tzinfo=TzInfo(-08:00)) │ │ ), │ │ ShieldCallStep( │ │ │ step_id='4dab8bb0-7d38-4465-ae1a-10069de2b3d1', │ │ │ step_type='shield_call', │ │ │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ │ │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 5, 428561, tzinfo=TzInfo(-08:00)), │ │ │ started_at=datetime.datetime(2025, 2, 27, 14, 0, 4, 825970, tzinfo=TzInfo(-08:00)), │ │ │ violation=None │ │ ) │ ], │ turn_id='4daff286-f703-417e-a5dc-0e158582bbec', │ completed_at=datetime.datetime(2025, 2, 27, 14, 0, 5, 462823, tzinfo=TzInfo(-08:00)), │ output_attachments=[] ) ``` [//]: # (## Documentation) --- .../agents/meta_reference/agent_instance.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e14a35463..3502c21f2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -125,13 +125,25 @@ class ChatAgent(ShieldRunnerMixin): def turn_to_messages(self, turn: Turn) -> List[Message]: messages = [] - # We do not want to keep adding RAG context to the input messages - # May be this should be a parameter of the agentic instance - # that can define its behavior in a custom way + # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages + tool_call_ids = set() + for step in turn.steps: + if step.step_type == StepType.tool_execution.value: + for response in step.tool_responses: + tool_call_ids.add(response.call_id) + for m in turn.input_messages: msg = m.model_copy() + # We do not want to keep adding RAG context to the input messages + # May be this should be a parameter of the agentic instance + # that can define its behavior in a custom way if isinstance(msg, UserMessage): msg.context = None + if isinstance(msg, ToolResponseMessage): + if msg.call_id in tool_call_ids: + # NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps + continue + messages.append(msg) for step in turn.steps: @@ -265,17 +277,24 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) + if len(turns) == 0: + raise ValueError("No turns found for session") + messages = await self.get_messages_from_turns(turns) messages.extend(request.tool_responses) + last_turn = turns[-1] + last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = [ - x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] + # TODO: figure out whether we should add the tool responses to the last turn messages + last_turn_messages.extend(request.tool_responses) + # get the steps from the turn id steps = [] - if len(turns) > 0: - steps = turns[-1].steps + steps = turns[-1].steps # mark tool execution step as complete # if there's no tool execution in progress step (due to storage, or tool call parsing on client),