continue turn

This commit is contained in:
Xi Yan 2025-02-20 18:00:57 -08:00
parent 22355e3b1f
commit 4923270122

View file

@ -270,18 +270,67 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(request.tool_responses)
# steps = []
# output_message = None
# async for chunk in self.run(
# session_id=request.session_id,
# turn_id=request.turn_id,
# input_messages=messages,
# sampling_params=self.agent_config.sampling_params,
# stream=request.stream,
# documents=request.documents,
# toolgroups_for_turn=request.toolgroups,
# ):
# if isinstance(chunk, CompletionMessage):
# get the steps from the turn id
steps = []
if len(turns) > 0:
steps = turns[-1].steps
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
last_turn_messages = []
last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
last_turn_messages = self.turn_to_messages(turns[-1])
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
self,