Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 13:00:39 +00:00
persist pending tool execution
parent 4923270122
commit 5e00e9f260

3 changed files with 69 additions and 1 deletion
@@ -275,6 +275,36 @@ class ChatAgent(ShieldRunnerMixin):
         if len(turns) > 0:
             steps = turns[-1].steps
+
+            # mark tool execution step as complete
+            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
+                request.session_id, request.turn_id
+            )
+            tool_execution_step = ToolExecutionStep(
+                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
+                turn_id=request.turn_id,
+                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
+                tool_responses=[
+                    ToolResponse(
+                        call_id=x.call_id,
+                        tool_name=x.tool_name,
+                        content=x.content,
+                    )
+                    for x in in_progress_tool_call_step.tool_responses
+                ],
+                completed_at=datetime.now(),
+                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else datetime.now()),
+            )
+            steps.append(tool_execution_step)
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
+                        step_type=StepType.tool_execution.value,
+                        step_id=tool_execution_step.step_id,
+                        step_details=tool_execution_step,
+                    )
+                )
+            )

         output_message = None
         async for chunk in self.run(
             session_id=request.session_id,
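Note: the hunk above reuses the persisted step's identity (step_id, tool_calls, started_at) when a pending step survived the suspension, and falls back to fresh values otherwise; the responses it records come from the persisted step's tool_responses. A minimal self-contained sketch of that merge, using a hypothetical PendingStep stand-in rather than llama-stack's real ToolExecutionStep, and taking the responses as a parameter so the None case is covered uniformly:

import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional

@dataclass
class PendingStep:  # hypothetical stand-in for ToolExecutionStep
    step_id: str
    tool_calls: List[str]
    started_at: datetime
    tool_responses: List[str] = field(default_factory=list)
    completed_at: Optional[datetime] = None

def complete_step(pending: Optional[PendingStep], responses: List[str]) -> PendingStep:
    # Reuse the persisted identity and start time if a pending step was found
    # in storage; otherwise synthesize fresh values, as the diff does.
    return PendingStep(
        step_id=pending.step_id if pending else str(uuid.uuid4()),
        tool_calls=pending.tool_calls if pending else [],
        started_at=pending.started_at if pending else datetime.now(),
        tool_responses=responses,
        completed_at=datetime.now(),
    )

fresh = complete_step(None, ["72F"])  # nothing persisted: new id, empty calls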
@@ -302,6 +332,14 @@ class ChatAgent(ShieldRunnerMixin):
             last_turn_start_time = turns[-1].started_at
             last_turn_messages = self.turn_to_messages(turns[-1])

+            # add tool responses to the last turn messages
+            last_turn_messages.extend(request.tool_responses)
+            # filter out non User / Tool messages
+            # TODO: should we just keep all message types in Turn.input_messages?
+            last_turn_messages = [
+                m for m in last_turn_messages if isinstance(m, UserMessage) or isinstance(m, ToolResponseMessage)
+            ]
+
         turn = Turn(
             turn_id=request.turn_id,
             session_id=request.session_id,
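The isinstance filter above can also be written with a tuple of types, which is equivalent and slightly tighter. A runnable sketch with empty stand-in classes in place of the real message models:

class UserMessage: ...
class ToolResponseMessage: ...
class CompletionMessage: ...  # assistant output; dropped by the filter

last_turn_messages = [UserMessage(), CompletionMessage(), ToolResponseMessage()]
last_turn_messages = [
    m for m in last_turn_messages if isinstance(m, (UserMessage, ToolResponseMessage))
]
assert len(last_turn_messages) == 2  # only user input and tool results remain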
@@ -739,6 +777,17 @@ class ChatAgent(ShieldRunnerMixin):

             # If tool is a client tool, yield CompletionMessage and return
             if tool_call.tool_name in client_tools:
+                await self.storage.set_in_progress_tool_call_step(
+                    session_id,
+                    turn_id,
+                    ToolExecutionStep(
+                        step_id=step_id,
+                        turn_id=turn_id,
+                        tool_calls=[tool_call],
+                        tool_responses=[],
+                        started_at=datetime.now(),
+                    ),
+                )
                 yield message
                 return

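The ordering here matters: the pending step (with empty tool_responses) is written to storage before the CompletionMessage is yielded, so the turn can be resumed even if the process restarts while the client runs the tool. A compressed, runnable sketch of that suspend path, with hypothetical names standing in for the real agent and storage classes:

import asyncio

class InMemoryStorage:  # hypothetical stand-in for AgentPersistence
    def __init__(self):
        self.pending = {}
    async def set_in_progress_tool_call_step(self, session_id, turn_id, step):
        self.pending[(session_id, turn_id)] = step

async def run_turn(storage, session_id, turn_id, tool_name, client_tools):
    if tool_name in client_tools:
        # Persist first, then hand off: a crash after this point loses nothing.
        await storage.set_in_progress_tool_call_step(
            session_id, turn_id, {"tool_calls": [tool_name], "tool_responses": []}
        )
        yield "completion_message_with_tool_call"
        return  # the turn suspends here until the client continues it
    yield "normal_completion"

async def demo():
    storage = InMemoryStorage()
    chunks = [c async for c in run_turn(storage, "s1", "t1", "get_weather", {"get_weather"})]
    assert chunks == ["completion_message_with_tool_call"]
    assert ("s1", "t1") in storage.pending

asyncio.run(demo())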
@@ -178,6 +178,13 @@ class MetaReferenceAgentsImpl(Agents):
         tool_responses: List[ToolResponseMessage],
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
+        request = AgentTurnContinueRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            turn_id=turn_id,
+            tool_responses=tool_responses,
+            stream=stream,
+        )
         if stream:
             return self._continue_agent_turn_streaming(request)
         else:
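The continue entrypoint packs its arguments into a single pydantic request model and then dispatches on stream, presumably mirroring the existing turn-creation flow. A sketch of the shape of that model; field names are taken from the hunk above, but the real tool_responses entries are ToolResponseMessage models, simplified here to dicts:

from typing import List, Optional
from pydantic import BaseModel

class AgentTurnContinueRequest(BaseModel):
    agent_id: str
    session_id: str
    turn_id: str
    tool_responses: List[dict]  # simplified; really List[ToolResponseMessage]
    stream: Optional[bool] = False

req = AgentTurnContinueRequest(
    agent_id="a1", session_id="s1", turn_id="t1",
    tool_responses=[{"call_id": "c1", "tool_name": "get_weather", "content": "72F"}],
    stream=True,
)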
@@ -12,7 +12,7 @@ from typing import List, Optional

 from pydantic import BaseModel

-from llama_stack.apis.agents import Turn
+from llama_stack.apis.agents import ToolExecutionStep, Turn
 from llama_stack.providers.utils.kvstore import KVStore

 log = logging.getLogger(__name__)
@@ -84,3 +84,15 @@ class AgentPersistence:
                 continue
         turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
+
+    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
+        await self.kvstore.set(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+            value=step.model_dump_json(),
+        )
+
+    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+        value = await self.kvstore.get(
+            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        return ToolExecutionStep(**json.loads(value)) if value else None
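A round-trip sketch of this persistence scheme, with a plain dict standing in for KVStore; the key scopes the pending step to a single (agent, session, turn). The read path revalidates through the pydantic constructor, exactly as the getter above does (with pydantic v2, model_validate_json would be an equivalent choice):

import asyncio
import json
from pydantic import BaseModel

class Step(BaseModel):  # hypothetical stand-in for ToolExecutionStep
    step_id: str
    turn_id: str

class DictKV:  # hypothetical stand-in for KVStore
    def __init__(self):
        self._d = {}
    async def set(self, key: str, value: str):
        self._d[key] = value
    async def get(self, key: str):
        return self._d.get(key)

async def main():
    kv = DictKV()
    key = "in_progress_tool_call_step:agent1:sess1:turn1"
    await kv.set(key=key, value=Step(step_id="s1", turn_id="turn1").model_dump_json())
    value = await kv.get(key=key)
    # Same shape as get_in_progress_tool_call_step: None-safe revalidation on read.
    step = Step(**json.loads(value)) if value else None
    assert step is not None and step.step_id == "s1"

asyncio.run(main())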