mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
continue turn
This commit is contained in:
parent
22355e3b1f
commit
4923270122
1 changed files with 61 additions and 12 deletions
|
@ -270,18 +270,67 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
messages.extend(request.tool_responses)
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
# steps = []
|
# get the steps from the turn id
|
||||||
# output_message = None
|
steps = []
|
||||||
# async for chunk in self.run(
|
if len(turns) > 0:
|
||||||
# session_id=request.session_id,
|
steps = turns[-1].steps
|
||||||
# turn_id=request.turn_id,
|
|
||||||
# input_messages=messages,
|
output_message = None
|
||||||
# sampling_params=self.agent_config.sampling_params,
|
async for chunk in self.run(
|
||||||
# stream=request.stream,
|
session_id=request.session_id,
|
||||||
# documents=request.documents,
|
turn_id=request.turn_id,
|
||||||
# toolgroups_for_turn=request.toolgroups,
|
input_messages=messages,
|
||||||
# ):
|
sampling_params=self.agent_config.sampling_params,
|
||||||
# if isinstance(chunk, CompletionMessage):
|
stream=request.stream,
|
||||||
|
):
|
||||||
|
if isinstance(chunk, CompletionMessage):
|
||||||
|
output_message = chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
|
event = chunk.event
|
||||||
|
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||||
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
assert output_message is not None
|
||||||
|
|
||||||
|
last_turn_messages = []
|
||||||
|
last_turn_start_time = datetime.now()
|
||||||
|
if len(turns) > 0:
|
||||||
|
last_turn_start_time = turns[-1].started_at
|
||||||
|
last_turn_messages = self.turn_to_messages(turns[-1])
|
||||||
|
|
||||||
|
turn = Turn(
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
session_id=request.session_id,
|
||||||
|
input_messages=last_turn_messages,
|
||||||
|
output_message=output_message,
|
||||||
|
started_at=last_turn_start_time,
|
||||||
|
completed_at=datetime.now(),
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
|
if output_message.tool_calls:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue