feat(2/n): agent return turn awaiting input for tool calls (#1187)

# What does this PR do?
- 1/n: https://github.com/meta-llama/llama-stack/pull/1178

This is a backward compat change. Default `allow_resume_turn=False` to
work with client v0.1.3.

**Expected Behaviour**

1. User sends out `agent.create_turn`. When user receives a custom
client_tool_call. The last chunk should be
```python
chunk = AgentTurnResponseStreamChunk(
        event=AgentTurnResponseEvent(
        payload=AgentTurnResponseTurnAwaitingInputPayload(
            turn=turn,
       )
    )
)
```

There should be a `tool_execution` step. 

2. When user sees turn_pending as the event_type or if there's tool call
in the response. They will execute the client tools, and submits back
the tool response. E.g.

```python
client.continue_turn(
    [ToolResponseMessage(...)]
)
```

The agent will then send back & record `tool_execution` step finish on
server. and continue with the execution.


[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

> NOTE: This PR handles (1). Next PR handle (2). 

## Test Plan

- Testing w/ Client SDK 0.1.3
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct
```
<img width="1075" alt="image"
src="https://github.com/user-attachments/assets/c94e9cc6-a326-405c-a2ef-925dbc3ae58e"
/>



[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-02-21 11:47:21 -08:00 committed by GitHub
parent fa0dfdeac2
commit 36a1e8bc57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 248 additions and 29 deletions

View file

@ -2319,7 +2319,7 @@
"post": {
"responses": {
"200": {
"description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.",
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
"content": {
"application/json": {
"schema": {
@ -2337,11 +2337,12 @@
"tags": [
"Agents"
],
"description": "",
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "The ID of the agent to resume.",
"required": true,
"schema": {
"type": "string"
@ -2350,6 +2351,7 @@
{
"name": "session_id",
"in": "path",
"description": "The ID of the session to resume.",
"required": true,
"schema": {
"type": "string"
@ -2358,6 +2360,7 @@
{
"name": "turn_id",
"in": "path",
"description": "The ID of the turn to resume.",
"required": true,
"schema": {
"type": "string"
@ -4287,6 +4290,9 @@
},
"tool_config": {
"$ref": "#/components/schemas/ToolConfig"
},
"allow_turn_resume": {
"type": "boolean"
}
},
"additionalProperties": false,
@ -8106,10 +8112,12 @@
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
}
},
"description": "The tool call responses to resume the turn with."
},
"stream": {
"type": "boolean"
"type": "boolean",
"description": "Whether to stream the response."
}
},
"additionalProperties": false,

View file

@ -1406,8 +1406,8 @@ paths:
responses:
'200':
description: >-
A single turn in an interaction with an Agentic System. **OR** streamed
agent turn completion response.
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
objects.
content:
application/json:
schema:
@ -1417,20 +1417,28 @@ paths:
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
tags:
- Agents
description: ''
description: >-
Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client
side tool calls, this endpoint can be used to submit the outputs from the
tool calls once they are ready.
parameters:
- name: agent_id
in: path
description: The ID of the agent to resume.
required: true
schema:
type: string
- name: session_id
in: path
description: The ID of the session to resume.
required: true
schema:
type: string
- name: turn_id
in: path
description: The ID of the turn to resume.
required: true
schema:
type: string
@ -2779,6 +2787,8 @@ components:
$ref: '#/components/schemas/AgentTool'
tool_config:
$ref: '#/components/schemas/ToolConfig'
allow_turn_resume:
type: boolean
additionalProperties: false
required:
- messages
@ -5242,8 +5252,11 @@ components:
type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >-
The tool call responses to resume the turn with.
stream:
type: boolean
description: Whether to stream the response.
additionalProperties: false
required:
- tool_responses

View file

@ -296,6 +296,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
@ -352,6 +355,7 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
@ -365,7 +369,19 @@ class Agents(Protocol):
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",

View file

@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
InferenceStep,
@ -62,7 +64,11 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
if output_message.tool_calls and request.allow_turn_resume:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
# get the steps from the turn id
steps = []
if len(turns) > 0:
steps = turns[-1].steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message]
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
if tool_call.tool_name in client_tools:
yield message
return
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
yield message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
Session,
Turn,
@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -177,7 +180,25 @@ class MetaReferenceAgentsImpl(Agents):
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
pass
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")

View file

@ -12,7 +12,7 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -84,3 +84,15 @@ class AgentPersistence:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None

View file

@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
from llama_stack.apis.agents.agents import ToolChoice
from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig,
)
from llama_stack.apis.agents.agents import (
ToolChoice,
)
class TestClientTool(ClientTool):