This commit is contained in:
Xi Yan 2025-02-20 14:54:45 -08:00
parent 5eea2bc44d
commit 07c9222b6f

View file

@ -30,8 +30,8 @@ from llama_stack.apis.agents import (
AgentTurnResponseStepProgressPayload, AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload, AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk, AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnPendingPayload,
AgentTurnResponseTurnStartPayload, AgentTurnResponseTurnStartPayload,
Attachment, Attachment,
Document, Document,
@ -225,12 +225,11 @@ class ChatAgent(ShieldRunnerMixin):
completed_at=datetime.now(), completed_at=datetime.now(),
steps=steps, steps=steps,
) )
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls: if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk( chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnPendingPayload( payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn, turn=turn,
) )
) )
@ -243,6 +242,8 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
# only add to storage if turn is complete
await self.storage.add_turn_to_session(request.session_id, turn)
yield chunk yield chunk
@ -626,11 +627,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
log.info(f"{str(message)}") log.info(f"{str(message)}")
tool_call = message.tool_calls[0] # 1. Start the tool execution step and progress
if tool_call.tool_name in client_tools:
yield message
return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
@ -640,6 +637,8 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
@ -654,6 +653,12 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
yield message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool): if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value tool_name = tool_name.value