mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
feat(2/n): agent return turn awaiting input for tool calls (#1187)
# What does this PR do? - 1/n: https://github.com/meta-llama/llama-stack/pull/1178 This is a backward-compatible change: `allow_turn_resume` defaults to `False` so that it works with client v0.1.3. **Expected Behaviour** 1. The user calls `agent.create_turn`. When the user receives a custom client_tool_call, the last chunk should be ```python chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseTurnAwaitingInputPayload( turn=turn, ) ) ) ``` There should be a `tool_execution` step. 2. When the user sees turn_pending as the event_type, or if there is a tool call in the response, they execute the client tools and submit back the tool response. E.g. ```python client.continue_turn( [ToolResponseMessage(...)] ) ``` The agent will then send back and record the completion of the `tool_execution` step on the server, and continue with the execution. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) > NOTE: This PR handles (1). The next PR will handle (2). ## Test Plan - Testing w/ Client SDK 0.1.3 ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` <img width="1075" alt="image" src="https://github.com/user-attachments/assets/c94e9cc6-a326-405c-a2ef-925dbc3ae58e" /> [//]: # (## Documentation)
This commit is contained in:
parent
fa0dfdeac2
commit
36a1e8bc57
7 changed files with 248 additions and 29 deletions
16
docs/_static/llama-stack-spec.html
vendored
16
docs/_static/llama-stack-spec.html
vendored
|
@ -2319,7 +2319,7 @@
|
|||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.",
|
||||
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
|
@ -2337,11 +2337,12 @@
|
|||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"description": "",
|
||||
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the agent to resume.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
|
@ -2350,6 +2351,7 @@
|
|||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the session to resume.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
|
@ -2358,6 +2360,7 @@
|
|||
{
|
||||
"name": "turn_id",
|
||||
"in": "path",
|
||||
"description": "The ID of the turn to resume.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
|
@ -4287,6 +4290,9 @@
|
|||
},
|
||||
"tool_config": {
|
||||
"$ref": "#/components/schemas/ToolConfig"
|
||||
},
|
||||
"allow_turn_resume": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -8106,10 +8112,12 @@
|
|||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolResponseMessage"
|
||||
}
|
||||
},
|
||||
"description": "The tool call responses to resume the turn with."
|
||||
},
|
||||
"stream": {
|
||||
"type": "boolean"
|
||||
"type": "boolean",
|
||||
"description": "Whether to stream the response."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
19
docs/_static/llama-stack-spec.yaml
vendored
19
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1406,8 +1406,8 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A single turn in an interaction with an Agentic System. **OR** streamed
|
||||
agent turn completion response.
|
||||
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
|
||||
objects.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
|
@ -1417,20 +1417,28 @@ paths:
|
|||
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
|
||||
tags:
|
||||
- Agents
|
||||
description: ''
|
||||
description: >-
|
||||
Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client
|
||||
side tool calls, this endpoint can be used to submit the outputs from the
|
||||
tool calls once they are ready.
|
||||
parameters:
|
||||
- name: agent_id
|
||||
in: path
|
||||
description: The ID of the agent to resume.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: session_id
|
||||
in: path
|
||||
description: The ID of the session to resume.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: turn_id
|
||||
in: path
|
||||
description: The ID of the turn to resume.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
@ -2779,6 +2787,8 @@ components:
|
|||
$ref: '#/components/schemas/AgentTool'
|
||||
tool_config:
|
||||
$ref: '#/components/schemas/ToolConfig'
|
||||
allow_turn_resume:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
required:
|
||||
- messages
|
||||
|
@ -5242,8 +5252,11 @@ components:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolResponseMessage'
|
||||
description: >-
|
||||
The tool call responses to resume the turn with.
|
||||
stream:
|
||||
type: boolean
|
||||
description: Whether to stream the response.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- tool_responses
|
||||
|
|
|
@ -296,6 +296,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
stream: Optional[bool] = False
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
|
||||
# TODO (xiyan): temporary flag, will remove for 0.1.5
|
||||
allow_turn_resume: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
|
@ -352,6 +355,7 @@ class Agents(Protocol):
|
|||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
|
||||
@webmethod(
|
||||
|
@ -365,7 +369,19 @@ class Agents(Protocol):
|
|||
turn_id: str,
|
||||
tool_responses: List[ToolResponseMessage],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||
|
||||
:param agent_id: The ID of the agent to resume.
|
||||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
|
|
|
@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResumeRequest,
|
||||
Attachment,
|
||||
Document,
|
||||
InferenceStep,
|
||||
|
@ -62,7 +64,11 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for turn in turns:
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
|
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
|
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
if output_message.tool_calls and request.allow_turn_resume:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
messages.extend(request.tool_responses)
|
||||
|
||||
last_turn_messages = [
|
||||
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||
]
|
||||
|
||||
# get the steps from the turn id
|
||||
steps = []
|
||||
if len(turns) > 0:
|
||||
steps = turns[-1].steps
|
||||
|
||||
# mark tool execution step as complete
|
||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||
# we'll create a new tool execution step with current time
|
||||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now()
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=x.call_id,
|
||||
tool_name=x.tool_name,
|
||||
content=x.content,
|
||||
)
|
||||
for x in request.tool_responses
|
||||
],
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
)
|
||||
steps.append(tool_execution_step)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=tool_execution_step.step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=request.turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
last_turn_start_time = datetime.now()
|
||||
if len(turns) > 0:
|
||||
last_turn_start_time = turns[-1].started_at
|
||||
|
||||
turn = Turn(
|
||||
turn_id=request.turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=last_turn_messages,
|
||||
output_message=output_message,
|
||||
started_at=last_turn_start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
if output_message.tool_calls:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
|
@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages = input_messages + [message]
|
||||
else:
|
||||
log.info(f"{str(message)}")
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name in client_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
tool_call = message.tool_calls[0]
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
# If tool is a client tool, yield CompletionMessage and return
|
||||
if tool_call.tool_name in client_tools:
|
||||
await self.storage.set_in_progress_tool_call_step(
|
||||
session_id,
|
||||
turn_id,
|
||||
ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
return
|
||||
|
||||
# If tool is a builtin server tool, execute it
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
|||
AgentStepResponse,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
Session,
|
||||
Turn,
|
||||
|
@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
allow_turn_resume: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
allow_turn_resume=allow_turn_resume,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
@ -177,7 +180,25 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_responses: List[ToolResponseMessage],
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
pass
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
tool_responses=tool_responses,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return self._continue_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _continue_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
|
|
|
@ -12,7 +12,7 @@ from typing import List, Optional
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -84,3 +84,15 @@ class AgentPersistence:
|
|||
continue
|
||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
|
|
@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
|
|||
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||
from llama_stack_client.types.tool_def_param import Parameter
|
||||
|
||||
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
|
||||
from llama_stack.apis.agents.agents import ToolChoice
|
||||
from llama_stack.apis.agents.agents import (
|
||||
AgentConfig as Server__AgentConfig,
|
||||
)
|
||||
from llama_stack.apis.agents.agents import (
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
|
||||
class TestClientTool(ClientTool):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue