persist pending tool execution

This commit is contained in:
Xi Yan 2025-02-20 19:33:21 -08:00
parent 4923270122
commit 5e00e9f260
3 changed files with 69 additions and 1 deletions

View file

@ -275,6 +275,36 @@ class ChatAgent(ShieldRunnerMixin):
if len(turns) > 0:
steps = turns[-1].steps
# mark tool execution step as complete
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in in_progress_tool_call_step.tool_responses
],
completed_at=datetime.now(),
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
@ -302,6 +332,14 @@ class ChatAgent(ShieldRunnerMixin):
last_turn_start_time = turns[-1].started_at
last_turn_messages = self.turn_to_messages(turns[-1])
# add tool responses to the last turn messages
last_turn_messages.extend(request.tool_responses)
# filter out non User / Tool messages
# TODO: should we just keep all message types in Turn.input_messages?
last_turn_messages = [
m for m in last_turn_messages if isinstance(m, UserMessage) or isinstance(m, ToolResponseMessage)
]
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
@ -739,6 +777,17 @@ class ChatAgent(ShieldRunnerMixin):
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
yield message
return

View file

@ -178,6 +178,13 @@ class MetaReferenceAgentsImpl(Agents):
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnContinueRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:

View file

@ -12,7 +12,7 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -84,3 +84,15 @@ class AgentPersistence:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None