mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
update resume
This commit is contained in:
parent
e8f7fa1ce1
commit
025ab9cd01
1 changed files with 69 additions and 134 deletions
|
@ -185,15 +185,25 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
span.set_attribute("request", request.model_dump_json())
|
span.set_attribute("request", request.model_dump_json())
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
span.set_attribute("turn_id", turn_id)
|
span.set_attribute("turn_id", turn_id)
|
||||||
assert request.stream is True, "Non-streaming not supported"
|
|
||||||
async for chunk in self._run_turn(request, turn_id):
|
async for chunk in self._run_turn(request, turn_id):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
|
with tracing.span("resume_turn") as span:
|
||||||
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
async for chunk in self._run_turn(request):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def _run_turn(
|
async def _run_turn(
|
||||||
self,
|
self,
|
||||||
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
||||||
turn_id: Optional[str] = None,
|
turn_id: Optional[str] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
is_resume = isinstance(request, AgentTurnResumeRequest)
|
is_resume = isinstance(request, AgentTurnResumeRequest)
|
||||||
session_info = await self.storage.get_session_info(request.session_id)
|
session_info = await self.storage.get_session_info(request.session_id)
|
||||||
if session_info is None:
|
if session_info is None:
|
||||||
|
@ -203,99 +213,19 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if is_resume and len(turns) == 0:
|
if is_resume and len(turns) == 0:
|
||||||
raise ValueError("No turns found for session")
|
raise ValueError("No turns found for session")
|
||||||
|
|
||||||
|
steps = []
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
if is_resume:
|
if is_resume:
|
||||||
messages.extend(request.tool_responses)
|
messages.extend(request.tool_responses)
|
||||||
turn_id = request.turn_id
|
|
||||||
start_time = turns[-1].started_at
|
|
||||||
else:
|
|
||||||
messages.extend(request.messages)
|
|
||||||
start_time = datetime.now().astimezone().isoformat()
|
|
||||||
|
|
||||||
steps = []
|
|
||||||
output_message = None
|
|
||||||
async for chunk in self.run(
|
|
||||||
session_id=request.session_id,
|
|
||||||
turn_id=turn_id,
|
|
||||||
input_messages=messages,
|
|
||||||
sampling_params=self.agent_config.sampling_params,
|
|
||||||
stream=request.stream,
|
|
||||||
documents=request.documents if not is_resume else None,
|
|
||||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
|
||||||
):
|
|
||||||
if isinstance(chunk, CompletionMessage):
|
|
||||||
output_message = chunk
|
|
||||||
continue
|
|
||||||
|
|
||||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
|
||||||
event = chunk.event
|
|
||||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
|
||||||
steps.append(event.payload.step_details)
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
assert output_message is not None
|
|
||||||
|
|
||||||
turn = Turn(
|
|
||||||
turn_id=turn_id,
|
|
||||||
session_id=request.session_id,
|
|
||||||
input_messages=request.messages,
|
|
||||||
output_message=output_message,
|
|
||||||
started_at=start_time,
|
|
||||||
completed_at=datetime.now().astimezone().isoformat(),
|
|
||||||
steps=steps,
|
|
||||||
)
|
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
|
||||||
if output_message.tool_calls:
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
|
||||||
turn=turn,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseTurnCompletePayload(
|
|
||||||
turn=turn,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
|
||||||
with tracing.span("resume_turn") as span:
|
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
|
||||||
span.set_attribute("session_id", request.session_id)
|
|
||||||
span.set_attribute("turn_id", request.turn_id)
|
|
||||||
span.set_attribute("request", request.model_dump_json())
|
|
||||||
assert request.stream is True, "Non-streaming not supported"
|
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(request.session_id)
|
|
||||||
if session_info is None:
|
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
|
||||||
if len(turns) == 0:
|
|
||||||
raise ValueError("No turns found for session")
|
|
||||||
|
|
||||||
messages = await self.get_messages_from_turns(turns)
|
|
||||||
messages.extend(request.tool_responses)
|
|
||||||
|
|
||||||
last_turn = turns[-1]
|
last_turn = turns[-1]
|
||||||
last_turn_messages = self.turn_to_messages(last_turn)
|
last_turn_messages = self.turn_to_messages(last_turn)
|
||||||
last_turn_messages = [
|
last_turn_messages = [
|
||||||
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: figure out whether we should add the tool responses to the last turn messages
|
|
||||||
last_turn_messages.extend(request.tool_responses)
|
last_turn_messages.extend(request.tool_responses)
|
||||||
|
|
||||||
# get the steps from the turn id
|
# get steps from the turn
|
||||||
steps = []
|
steps = last_turn.steps
|
||||||
steps = turns[-1].steps
|
|
||||||
|
|
||||||
# mark tool execution step as complete
|
# mark tool execution step as complete
|
||||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
@ -329,14 +259,24 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
input_messages = last_turn_messages
|
||||||
|
|
||||||
|
turn_id = request.turn_id
|
||||||
|
start_time = last_turn.started_at
|
||||||
|
else:
|
||||||
|
messages.extend(request.messages)
|
||||||
|
start_time = datetime.now().astimezone().isoformat()
|
||||||
|
input_messages = request.messages
|
||||||
|
|
||||||
output_message = None
|
output_message = None
|
||||||
async for chunk in self.run(
|
async for chunk in self.run(
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
turn_id=request.turn_id,
|
turn_id=turn_id,
|
||||||
input_messages=messages,
|
input_messages=messages,
|
||||||
sampling_params=self.agent_config.sampling_params,
|
sampling_params=self.agent_config.sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
|
documents=request.documents if not is_resume else None,
|
||||||
|
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
|
@ -351,21 +291,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
assert output_message is not None
|
assert output_message is not None
|
||||||
|
|
||||||
last_turn_start_time = datetime.now().astimezone().isoformat()
|
|
||||||
if len(turns) > 0:
|
|
||||||
last_turn_start_time = turns[-1].started_at
|
|
||||||
|
|
||||||
turn = Turn(
|
turn = Turn(
|
||||||
turn_id=request.turn_id,
|
turn_id=turn_id,
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
input_messages=last_turn_messages,
|
input_messages=input_messages,
|
||||||
output_message=output_message,
|
output_message=output_message,
|
||||||
started_at=last_turn_start_time,
|
started_at=start_time,
|
||||||
completed_at=datetime.now().astimezone().isoformat(),
|
completed_at=datetime.now().astimezone().isoformat(),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
if output_message.tool_calls:
|
if output_message.tool_calls:
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue