From 025ab9cd01036faf43217863b704e96722a5f645 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 4 Mar 2025 13:52:11 -0800 Subject: [PATCH] update resume --- .../agents/meta_reference/agent_instance.py | 203 ++++++------------ 1 file changed, 69 insertions(+), 134 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e4409d6a7..7cd4c868d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -185,15 +185,25 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("request", request.model_dump_json()) turn_id = str(uuid.uuid4()) span.set_attribute("turn_id", turn_id) - assert request.stream is True, "Non-streaming not supported" async for chunk in self._run_turn(request, turn_id): yield chunk + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + with tracing.span("resume_turn") as span: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("turn_id", request.turn_id) + span.set_attribute("request", request.model_dump_json()) + async for chunk in self._run_turn(request): + yield chunk + async def _run_turn( self, request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest], turn_id: Optional[str] = None, ) -> AsyncGenerator: + assert request.stream is True, "Non-streaming not supported" + is_resume = isinstance(request, AgentTurnResumeRequest) session_info = await self.storage.get_session_info(request.session_id) if session_info is None: @@ -203,99 +213,19 @@ class ChatAgent(ShieldRunnerMixin): if is_resume and len(turns) == 0: raise ValueError("No turns found for session") + steps = [] messages = await self.get_messages_from_turns(turns) if is_resume: messages.extend(request.tool_responses) - turn_id = request.turn_id - start_time = turns[-1].started_at - else: - messages.extend(request.messages) - start_time = datetime.now().astimezone().isoformat() - - steps = [] - output_message = None - async for chunk in self.run( - session_id=request.session_id, - turn_id=turn_id, - input_messages=messages, - sampling_params=self.agent_config.sampling_params, - stream=request.stream, - documents=request.documents if not is_resume else None, - toolgroups_for_turn=request.toolgroups if not is_resume else None, - ): - if isinstance(chunk, CompletionMessage): - output_message = chunk - continue - - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" - event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: - steps.append(event.payload.step_details) - - yield chunk - - assert output_message is not None - - turn = Turn( - turn_id=turn_id, - session_id=request.session_id, - input_messages=request.messages, - output_message=output_message, - started_at=start_time, - completed_at=datetime.now().astimezone().isoformat(), - steps=steps, - ) - await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnAwaitingInputPayload( - turn=turn, - ) - ) - ) - else: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, - ) - ) - ) - - yield chunk - - async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: - with tracing.span("resume_turn") as span: - span.set_attribute("agent_id", self.agent_id) - span.set_attribute("session_id", request.session_id) - span.set_attribute("turn_id", request.turn_id) - span.set_attribute("request", request.model_dump_json()) - assert request.stream is True, "Non-streaming not supported" - - session_info = await self.storage.get_session_info(request.session_id) - if session_info is None: - raise ValueError(f"Session {request.session_id} not found") - - turns = await self.storage.get_session_turns(request.session_id) - if len(turns) == 0: - raise ValueError("No turns found for session") - - messages = await self.get_messages_from_turns(turns) - messages.extend(request.tool_responses) - last_turn = turns[-1] last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = [ x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) ] - - # TODO: figure out whether we should add the tool responses to the last turn messages last_turn_messages.extend(request.tool_responses) - # get the steps from the turn id - steps = [] - steps = turns[-1].steps + # get steps from the turn + steps = last_turn.steps # mark tool execution step as complete # if there's no tool execution in progress step (due to storage, or tool call parsing on client), @@ -329,62 +259,67 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) + input_messages = last_turn_messages - output_message = None - async for chunk in self.run( - session_id=request.session_id, - turn_id=request.turn_id, - input_messages=messages, - sampling_params=self.agent_config.sampling_params, - stream=request.stream, - ): - if isinstance(chunk, CompletionMessage): - output_message = chunk - continue + turn_id = request.turn_id + start_time = last_turn.started_at + else: + messages.extend(request.messages) + start_time = datetime.now().astimezone().isoformat() + input_messages = request.messages - assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" - event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: - steps.append(event.payload.step_details) + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + documents=request.documents if not is_resume else None, + toolgroups_for_turn=request.toolgroups if not is_resume else None, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue - yield chunk - - assert output_message is not None - - last_turn_start_time = datetime.now().astimezone().isoformat() - if len(turns) > 0: - last_turn_start_time = turns[-1].started_at - - turn = Turn( - turn_id=request.turn_id, - session_id=request.session_id, - input_messages=last_turn_messages, - output_message=output_message, - started_at=last_turn_start_time, - completed_at=datetime.now().astimezone().isoformat(), - steps=steps, - ) - await self.storage.add_turn_to_session(request.session_id, turn) - - if output_message.tool_calls: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnAwaitingInputPayload( - turn=turn, - ) - ) - ) - else: - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, - ) - ) - ) + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) yield chunk + assert output_message is not None + + turn = Turn( + turn_id=turn_id, + session_id=request.session_id, + input_messages=input_messages, + output_message=output_message, + started_at=start_time, + completed_at=datetime.now().astimezone().isoformat(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk + async def run( self, session_id: str,