update resume

This commit is contained in:
Xi Yan 2025-03-04 13:52:11 -08:00
parent e8f7fa1ce1
commit 025ab9cd01

View file

@ -185,15 +185,25 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id) span.set_attribute("turn_id", turn_id)
assert request.stream is True, "Non-streaming not supported"
async for chunk in self._run_turn(request, turn_id): async for chunk in self._run_turn(request, turn_id):
yield chunk yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
async for chunk in self._run_turn(request):
yield chunk
async def _run_turn( async def _run_turn(
self, self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest], request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
turn_id: Optional[str] = None, turn_id: Optional[str] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
is_resume = isinstance(request, AgentTurnResumeRequest) is_resume = isinstance(request, AgentTurnResumeRequest)
session_info = await self.storage.get_session_info(request.session_id) session_info = await self.storage.get_session_info(request.session_id)
if session_info is None: if session_info is None:
@ -203,99 +213,19 @@ class ChatAgent(ShieldRunnerMixin):
if is_resume and len(turns) == 0: if is_resume and len(turns) == 0:
raise ValueError("No turns found for session") raise ValueError("No turns found for session")
steps = []
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
if is_resume: if is_resume:
messages.extend(request.tool_responses) messages.extend(request.tool_responses)
turn_id = request.turn_id
start_time = turns[-1].started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents if not is_resume else None,
toolgroups_for_turn=request.toolgroups if not is_resume else None,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
if len(turns) == 0:
raise ValueError("No turns found for session")
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
last_turn = turns[-1] last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [ last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
] ]
# TODO: figure out whether we should add the tool responses to the last turn messages
last_turn_messages.extend(request.tool_responses) last_turn_messages.extend(request.tool_responses)
# get the steps from the turn id # get steps from the turn
steps = [] steps = last_turn.steps
steps = turns[-1].steps
# mark tool execution step as complete # mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client), # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
@ -329,14 +259,24 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
input_messages = last_turn_messages
turn_id = request.turn_id
start_time = last_turn.started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
input_messages = request.messages
output_message = None output_message = None
async for chunk in self.run( async for chunk in self.run(
session_id=request.session_id, session_id=request.session_id,
turn_id=request.turn_id, turn_id=turn_id,
input_messages=messages, input_messages=messages,
sampling_params=self.agent_config.sampling_params, sampling_params=self.agent_config.sampling_params,
stream=request.stream, stream=request.stream,
documents=request.documents if not is_resume else None,
toolgroups_for_turn=request.toolgroups if not is_resume else None,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
output_message = chunk output_message = chunk
@ -351,21 +291,16 @@ class ChatAgent(ShieldRunnerMixin):
assert output_message is not None assert output_message is not None
last_turn_start_time = datetime.now().astimezone().isoformat()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn( turn = Turn(
turn_id=request.turn_id, turn_id=turn_id,
session_id=request.session_id, session_id=request.session_id,
input_messages=last_turn_messages, input_messages=input_messages,
output_message=output_message, output_message=output_message,
started_at=last_turn_start_time, started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(), completed_at=datetime.now().astimezone().isoformat(),
steps=steps, steps=steps,
) )
await self.storage.add_turn_to_session(request.session_id, turn) await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls: if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk( chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(