diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e14a35463..3502c21f2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -125,13 +125,25 @@ class ChatAgent(ShieldRunnerMixin): def turn_to_messages(self, turn: Turn) -> List[Message]: messages = [] - # We do not want to keep adding RAG context to the input messages - # May be this should be a parameter of the agentic instance - # that can define its behavior in a custom way + # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages + tool_call_ids = set() + for step in turn.steps: + if step.step_type == StepType.tool_execution.value: + for response in step.tool_responses: + tool_call_ids.add(response.call_id) + for m in turn.input_messages: msg = m.model_copy() + # We do not want to keep adding RAG context to the input messages + # May be this should be a parameter of the agentic instance + # that can define its behavior in a custom way if isinstance(msg, UserMessage): msg.context = None + if isinstance(msg, ToolResponseMessage): + if msg.call_id in tool_call_ids: + # NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps + continue + messages.append(msg) for step in turn.steps: @@ -265,17 +277,24 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) + if len(turns) == 0: + raise ValueError("No turns found for session") + messages = await self.get_messages_from_turns(turns) messages.extend(request.tool_responses) + last_turn = turns[-1] + last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = [ - x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] + # TODO: figure out whether we should add the tool responses to the last turn messages + last_turn_messages.extend(request.tool_responses) + # get the steps from the turn id steps = [] - if len(turns) > 0: - steps = turns[-1].steps + steps = turns[-1].steps # mark tool execution step as complete # if there's no tool execution in progress step (due to storage, or tool call parsing on client),