mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
persist pending tool execution
This commit is contained in:
parent
4923270122
commit
5e00e9f260
3 changed files with 69 additions and 1 deletions
|
@ -275,6 +275,36 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if len(turns) > 0:
|
||||
steps = turns[-1].steps
|
||||
|
||||
# mark tool execution step as complete
|
||||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=x.call_id,
|
||||
tool_name=x.tool_name,
|
||||
content=x.content,
|
||||
)
|
||||
for x in in_progress_tool_call_step.tool_responses
|
||||
],
|
||||
completed_at=datetime.now(),
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()),
|
||||
)
|
||||
steps.append(tool_execution_step)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=tool_execution_step.step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
|
@ -302,6 +332,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
last_turn_start_time = turns[-1].started_at
|
||||
last_turn_messages = self.turn_to_messages(turns[-1])
|
||||
|
||||
# add tool responses to the last turn messages
|
||||
last_turn_messages.extend(request.tool_responses)
|
||||
# filter out non User / Tool messages
|
||||
# TODO: should we just keep all message types in Turn.input_messages?
|
||||
last_turn_messages = [
|
||||
m for m in last_turn_messages if isinstance(m, UserMessage) or isinstance(m, ToolResponseMessage)
|
||||
]
|
||||
|
||||
turn = Turn(
|
||||
turn_id=request.turn_id,
|
||||
session_id=request.session_id,
|
||||
|
@ -739,6 +777,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
return
|
||||
|
||||
|
|
|
@ -178,6 +178,13 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_responses: List[ToolResponseMessage],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnContinueRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
tool_responses=tool_responses,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return self._continue_agent_turn_streaming(request)
|
||||
else:
|
||||
|
|
|
@ -12,7 +12,7 @@ from typing import List, Optional
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -84,3 +84,15 @@ class AgentPersistence:
|
|||
continue
|
||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue