diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index f7960185e..b0d822511 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -275,6 +275,36 @@ class ChatAgent(ShieldRunnerMixin):
         if len(turns) > 0:
             steps = turns[-1].steps
 
+        # mark tool execution step as complete
+        in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
+            request.session_id, request.turn_id
+        )
+        tool_execution_step = ToolExecutionStep(
+            step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
+            turn_id=request.turn_id,
+            tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
+            tool_responses=[
+                ToolResponse(
+                    call_id=x.call_id,
+                    tool_name=x.tool_name,
+                    content=x.content,
+                )
+                for x in request.tool_responses
+            ],
+            completed_at=datetime.now(),
+            started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()),
+        )
+        steps.append(tool_execution_step)
+        yield AgentTurnResponseStreamChunk(
+            event=AgentTurnResponseEvent(
+                payload=AgentTurnResponseStepCompletePayload(
+                    step_type=StepType.tool_execution.value,
+                    step_id=tool_execution_step.step_id,
+                    step_details=tool_execution_step,
+                )
+            )
+        )
+
         output_message = None
         async for chunk in self.run(
             session_id=request.session_id,
@@ -302,6 +332,14 @@ class ChatAgent(ShieldRunnerMixin):
         last_turn_start_time = turns[-1].started_at
         last_turn_messages = self.turn_to_messages(turns[-1])
 
+        # add tool responses to the last turn messages
+        last_turn_messages.extend(request.tool_responses)
+
+        # filter out non User / Tool messages
+        # TODO: should we just keep all message types in Turn.input_messages?
+        last_turn_messages = [
+            m for m in last_turn_messages if isinstance(m, UserMessage) or isinstance(m, ToolResponseMessage)
+        ]
+
         turn = Turn(
             turn_id=request.turn_id,
             session_id=request.session_id,
@@ -739,6 +777,17 @@ class ChatAgent(ShieldRunnerMixin):
 
         # If tool is a client tool, yield CompletionMessage and return
         if tool_call.tool_name in client_tools:
+            await self.storage.set_in_progress_tool_call_step(
+                session_id,
+                turn_id,
+                ToolExecutionStep(
+                    step_id=step_id,
+                    turn_id=turn_id,
+                    tool_calls=[tool_call],
+                    tool_responses=[],
+                    started_at=datetime.now(),
+                ),
+            )
             yield message
             return
 
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index bdde89016..35038b339 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -178,6 +178,13 @@ class MetaReferenceAgentsImpl(Agents):
         tool_responses: List[ToolResponseMessage],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
+        request = AgentTurnContinueRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            turn_id=turn_id,
+            tool_responses=tool_responses,
+            stream=stream,
+        )
         if stream:
             return self._continue_agent_turn_streaming(request)
         else:
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index 4b8ad6d4a..3c3866873 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -12,7 +12,7 @@ from typing import List, Optional
 
 from pydantic import BaseModel
 
-from llama_stack.apis.agents import Turn
+from llama_stack.apis.agents import ToolExecutionStep, Turn
 from llama_stack.providers.utils.kvstore import KVStore
 
 log = logging.getLogger(__name__)
@@ -84,3 +84,15 @@ class AgentPersistence:
                 continue
         turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
+
+    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
+        await self.kvstore.set(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+            value=step.model_dump_json(),
+        )
+
+    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+        value = await self.kvstore.get(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        return ToolExecutionStep(**json.loads(value)) if value else None
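For orientation, a minimal sketch of how a caller might drive this flow once a turn pauses on a client tool call. The `continue_agent_turn` method name, the `async def` calling convention, and the `ToolResponseMessage` import path are assumptions (the diff only shows the method's trailing parameters and `AgentTurnContinueRequest`); the driver function below is purely illustrative, not part of this patch.

```python
# Illustrative sketch, not part of the patch. Assumptions: the continue endpoint
# is exposed as `continue_agent_turn` on the agents implementation, it is an
# async def that returns the stream when stream=True, and ToolResponseMessage
# lives in llama_stack.apis.inference. Its fields (call_id, tool_name, content)
# are taken from the diff above.
from typing import List

from llama_stack.apis.inference import ToolResponseMessage


async def resume_paused_turn(
    agents_impl,  # e.g. a MetaReferenceAgentsImpl instance
    agent_id: str,
    session_id: str,
    turn_id: str,
    tool_responses: List[ToolResponseMessage],
) -> None:
    # Hand the client tool results back; the agent looks up the persisted
    # in-progress ToolExecutionStep, emits a step_complete chunk for it,
    # and resumes the turn.
    stream = await agents_impl.continue_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn_id,
        tool_responses=tool_responses,
        stream=True,
    )
    async for chunk in stream:
        print(chunk)
```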