diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index bd12222a8..f7960185e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -270,18 +270,67 @@ class ChatAgent(ShieldRunnerMixin): messages.extend(request.tool_responses) - # steps = [] - # output_message = None - # async for chunk in self.run( - # session_id=request.session_id, - # turn_id=request.turn_id, - # input_messages=messages, - # sampling_params=self.agent_config.sampling_params, - # stream=request.stream, - # documents=request.documents, - # toolgroups_for_turn=request.toolgroups, - # ): - # if isinstance(chunk, CompletionMessage): + # get the steps from the turn id + steps = [] + if len(turns) > 0: + steps = turns[-1].steps + + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=request.turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) + + yield chunk + + assert output_message is not None + + last_turn_messages = [] + last_turn_start_time = datetime.now() + if len(turns) > 0: + last_turn_start_time = turns[-1].started_at + last_turn_messages = self.turn_to_messages(turns[-1]) + + turn = Turn( + turn_id=request.turn_id, + session_id=request.session_id, + input_messages=last_turn_messages, + output_message=output_message, + started_at=last_turn_start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk async def run( self,