feat(1/n): api: unify agents for handling server & client tools (#1178)

# Problem

Our current Agent framework has discrepancies in definition on how we
handle server side and client side tools.

1. Server Tools: a single Turn is returned including `ToolExecutionStep`
in agenst
2. Client Tools: `create_agent_turn` is called in loop with client agent
lib yielding the agent chunk

ad6ffc63df/src/llama_stack_client/lib/agents/agent.py (L186-L211)

This makes it inconsistent to work with server & client tools. It also
complicates the logs to telemetry to get information about agents turn /
history for observability.

#### Principle
The same `turn_id` should be used to represent the steps required to
complete a user message including client tools.

## Solution

1. `AgentTurnResponseEventType.turn_awaiting_input` status to indicate
that the current turn is not completed, and awaiting tool input
2. `continue_agent_turn` endpoint to update agent turn with client's
tool response.


# What does this PR do?
- Skeleton API as example

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

- Just API update, no functionality change
```
llama stack run + client-sdk test
```

<img width="842" alt="image"
src="https://github.com/user-attachments/assets/7ac56b5f-f424-4632-9476-7e0f57555bc3"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-02-21 11:48:27 -08:00 committed by GitHub
parent 992f865b2e
commit 0fe071764f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 454 additions and 21 deletions

View file

@ -2315,6 +2315,70 @@
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
"200": {
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Turn"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
}
}
}
}
},
"tags": [
"Agents"
],
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
"parameters": [
{
"name": "agent_id",
"in": "path",
"description": "The ID of the agent to resume.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "session_id",
"in": "path",
"description": "The ID of the session to resume.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "turn_id",
"in": "path",
"description": "The ID of the turn to resume.",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResumeAgentTurnRequest"
}
}
},
"required": true
}
}
},
"/v1/eval/benchmarks/{benchmark_id}/jobs": {
"post": {
"responses": {
@ -4226,6 +4290,9 @@
},
"tool_config": {
"$ref": "#/components/schemas/ToolConfig"
},
"allow_turn_resume": {
"type": "boolean"
}
},
"additionalProperties": false,
@ -4612,6 +4679,9 @@
},
{
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
},
{
"$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
}
],
"discriminator": {
@ -4621,7 +4691,8 @@
"step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
"step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
"turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload",
"turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
}
}
},
@ -4784,6 +4855,25 @@
"title": "AgentTurnResponseStreamChunk",
"description": "streamed agent turn completion response."
},
"AgentTurnResponseTurnAwaitingInputPayload": {
"type": "object",
"properties": {
"event_type": {
"type": "string",
"const": "turn_awaiting_input",
"default": "turn_awaiting_input"
},
"turn": {
"$ref": "#/components/schemas/Turn"
}
},
"additionalProperties": false,
"required": [
"event_type",
"turn"
],
"title": "AgentTurnResponseTurnAwaitingInputPayload"
},
"AgentTurnResponseTurnCompletePayload": {
"type": "object",
"properties": {
@ -8046,6 +8136,27 @@
],
"title": "RegisterVectorDbRequest"
},
"ResumeAgentTurnRequest": {
"type": "object",
"properties": {
"tool_responses": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
},
"description": "The tool call responses to resume the turn with."
},
"stream": {
"type": "boolean",
"description": "Whether to stream the response."
}
},
"additionalProperties": false,
"required": [
"tool_responses"
],
"title": "ResumeAgentTurnRequest"
},
"RunEvalRequest": {
"type": "object",
"properties": {

View file

@ -1401,6 +1401,53 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
'200':
description: >-
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
objects.
content:
application/json:
schema:
$ref: '#/components/schemas/Turn'
text/event-stream:
schema:
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
tags:
- Agents
description: >-
Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client
side tool calls, this endpoint can be used to submit the outputs from the
tool calls once they are ready.
parameters:
- name: agent_id
in: path
description: The ID of the agent to resume.
required: true
schema:
type: string
- name: session_id
in: path
description: The ID of the session to resume.
required: true
schema:
type: string
- name: turn_id
in: path
description: The ID of the turn to resume.
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/ResumeAgentTurnRequest'
required: true
/v1/eval/benchmarks/{benchmark_id}/jobs:
post:
responses:
@ -2740,6 +2787,8 @@ components:
$ref: '#/components/schemas/AgentTool'
tool_config:
$ref: '#/components/schemas/ToolConfig'
allow_turn_resume:
type: boolean
additionalProperties: false
required:
- messages
@ -2992,6 +3041,7 @@ components:
- $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
- $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
discriminator:
propertyName: event_type
mapping:
@ -3000,6 +3050,7 @@ components:
step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
AgentTurnResponseStepCompletePayload:
type: object
properties:
@ -3106,6 +3157,21 @@ components:
- event
title: AgentTurnResponseStreamChunk
description: streamed agent turn completion response.
"AgentTurnResponseTurnAwaitingInputPayload":
type: object
properties:
event_type:
type: string
const: turn_awaiting_input
default: turn_awaiting_input
turn:
$ref: '#/components/schemas/Turn'
additionalProperties: false
required:
- event_type
- turn
title: >-
AgentTurnResponseTurnAwaitingInputPayload
AgentTurnResponseTurnCompletePayload:
type: object
properties:
@ -5205,6 +5271,22 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
ResumeAgentTurnRequest:
type: object
properties:
tool_responses:
type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >-
The tool call responses to resume the turn with.
stream:
type: boolean
description: Whether to stream the response.
additionalProperties: false
required:
- tool_responses
title: ResumeAgentTurnRequest
RunEvalRequest:
type: object
properties:

View file

@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
AgentTurnResponseEventType.turn_awaiting_input.value
)
turn: Turn
AgentTurnResponseEventPayload = register_schema(
Annotated[
Union[
@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
Field(discriminator="event_type"),
],
@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponseMessage]
stream: Optional[bool] = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
@ -333,8 +355,34 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",

View file

@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
InferenceStep,
@ -62,7 +64,11 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
if output_message.tool_calls and request.allow_turn_resume:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
# get the steps from the turn id
steps = []
if len(turns) > 0:
steps = turns[-1].steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message]
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
if tool_call.tool_name in client_tools:
yield message
return
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
yield message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
Session,
Turn,
@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -169,6 +172,34 @@ class MetaReferenceAgentsImpl(Agents):
async for event in agent.create_and_execute_turn(request):
yield event
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)

View file

@ -12,7 +12,7 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -84,3 +84,15 @@ class AgentPersistence:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None

View file

@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
from llama_stack.apis.agents.agents import ToolChoice
from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig,
)
from llama_stack.apis.agents.agents import (
ToolChoice,
)
class TestClientTool(ClientTool):