mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
feat(2/n): agent return turn awaiting input for tool calls (#1187)
# What does this PR do? - 1/n: https://github.com/meta-llama/llama-stack/pull/1178 This is a backward compat change. Default `allow_resume_turn=False` to work with client v0.1.3. **Expected Behaviour** 1. User sends out `agent.create_turn`. When user receives a custom client_tool_call. The last chunk should be ```python chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseTurnAwaitingInputPayload( turn=turn, ) ) ) ``` There should be a `tool_execution` step. 2. When user sees turn_pending as the event_type or if there's tool call in the response. They will execute the client tools, and submits back the tool response. E.g. ```python client.continue_turn( [ToolResponseMessage(...)] ) ``` The agent will then send back & record `tool_execution` step finish on server. and continue with the execution. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) > NOTE: This PR handles (1). Next PR handle (2). ## Test Plan - Testing w/ Client SDK 0.1.3 ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` <img width="1075" alt="image" src="https://github.com/user-attachments/assets/c94e9cc6-a326-405c-a2ef-925dbc3ae58e" /> [//]: # (## Documentation)
This commit is contained in:
parent
fa0dfdeac2
commit
36a1e8bc57
7 changed files with 248 additions and 29 deletions
16
docs/_static/llama-stack-spec.html
vendored
16
docs/_static/llama-stack-spec.html
vendored
|
@ -2319,7 +2319,7 @@
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.",
|
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
|
@ -2337,11 +2337,12 @@
|
||||||
"tags": [
|
"tags": [
|
||||||
"Agents"
|
"Agents"
|
||||||
],
|
],
|
||||||
"description": "",
|
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
|
||||||
"parameters": [
|
"parameters": [
|
||||||
{
|
{
|
||||||
"name": "agent_id",
|
"name": "agent_id",
|
||||||
"in": "path",
|
"in": "path",
|
||||||
|
"description": "The ID of the agent to resume.",
|
||||||
"required": true,
|
"required": true,
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
@ -2350,6 +2351,7 @@
|
||||||
{
|
{
|
||||||
"name": "session_id",
|
"name": "session_id",
|
||||||
"in": "path",
|
"in": "path",
|
||||||
|
"description": "The ID of the session to resume.",
|
||||||
"required": true,
|
"required": true,
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
@ -2358,6 +2360,7 @@
|
||||||
{
|
{
|
||||||
"name": "turn_id",
|
"name": "turn_id",
|
||||||
"in": "path",
|
"in": "path",
|
||||||
|
"description": "The ID of the turn to resume.",
|
||||||
"required": true,
|
"required": true,
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
@ -4287,6 +4290,9 @@
|
||||||
},
|
},
|
||||||
"tool_config": {
|
"tool_config": {
|
||||||
"$ref": "#/components/schemas/ToolConfig"
|
"$ref": "#/components/schemas/ToolConfig"
|
||||||
|
},
|
||||||
|
"allow_turn_resume": {
|
||||||
|
"type": "boolean"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -8106,10 +8112,12 @@
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"$ref": "#/components/schemas/ToolResponseMessage"
|
"$ref": "#/components/schemas/ToolResponseMessage"
|
||||||
}
|
},
|
||||||
|
"description": "The tool call responses to resume the turn with."
|
||||||
},
|
},
|
||||||
"stream": {
|
"stream": {
|
||||||
"type": "boolean"
|
"type": "boolean",
|
||||||
|
"description": "Whether to stream the response."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|
19
docs/_static/llama-stack-spec.yaml
vendored
19
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1406,8 +1406,8 @@ paths:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: >-
|
description: >-
|
||||||
A single turn in an interaction with an Agentic System. **OR** streamed
|
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
|
||||||
agent turn completion response.
|
objects.
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
|
@ -1417,20 +1417,28 @@ paths:
|
||||||
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
|
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
|
||||||
tags:
|
tags:
|
||||||
- Agents
|
- Agents
|
||||||
description: ''
|
description: >-
|
||||||
|
Resume an agent turn with executed tool call responses.
|
||||||
|
|
||||||
|
When a Turn has the status `awaiting_input` due to pending input from client
|
||||||
|
side tool calls, this endpoint can be used to submit the outputs from the
|
||||||
|
tool calls once they are ready.
|
||||||
parameters:
|
parameters:
|
||||||
- name: agent_id
|
- name: agent_id
|
||||||
in: path
|
in: path
|
||||||
|
description: The ID of the agent to resume.
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
- name: session_id
|
- name: session_id
|
||||||
in: path
|
in: path
|
||||||
|
description: The ID of the session to resume.
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
- name: turn_id
|
- name: turn_id
|
||||||
in: path
|
in: path
|
||||||
|
description: The ID of the turn to resume.
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
@ -2779,6 +2787,8 @@ components:
|
||||||
$ref: '#/components/schemas/AgentTool'
|
$ref: '#/components/schemas/AgentTool'
|
||||||
tool_config:
|
tool_config:
|
||||||
$ref: '#/components/schemas/ToolConfig'
|
$ref: '#/components/schemas/ToolConfig'
|
||||||
|
allow_turn_resume:
|
||||||
|
type: boolean
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- messages
|
- messages
|
||||||
|
@ -5242,8 +5252,11 @@ components:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/ToolResponseMessage'
|
$ref: '#/components/schemas/ToolResponseMessage'
|
||||||
|
description: >-
|
||||||
|
The tool call responses to resume the turn with.
|
||||||
stream:
|
stream:
|
||||||
type: boolean
|
type: boolean
|
||||||
|
description: Whether to stream the response.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- tool_responses
|
- tool_responses
|
||||||
|
|
|
@ -296,6 +296,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
tool_config: Optional[ToolConfig] = None
|
tool_config: Optional[ToolConfig] = None
|
||||||
|
|
||||||
|
# TODO (xiyan): temporary flag, will remove for 0.1.5
|
||||||
|
allow_turn_resume: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResumeRequest(BaseModel):
|
class AgentTurnResumeRequest(BaseModel):
|
||||||
|
@ -352,6 +355,7 @@ class Agents(Protocol):
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
allow_turn_resume: Optional[bool] = False,
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(
|
||||||
|
@ -365,7 +369,19 @@ class Agents(Protocol):
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
tool_responses: List[ToolResponseMessage],
|
tool_responses: List[ToolResponseMessage],
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||||
|
"""Resume an agent turn with executed tool call responses.
|
||||||
|
|
||||||
|
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||||
|
|
||||||
|
:param agent_id: The ID of the agent to resume.
|
||||||
|
:param session_id: The ID of the session to resume.
|
||||||
|
:param turn_id: The ID of the turn to resume.
|
||||||
|
:param tool_responses: The tool call responses to resume the turn with.
|
||||||
|
:param stream: Whether to stream the response.
|
||||||
|
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(
|
||||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||||
|
|
|
@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnResponseStepProgressPayload,
|
AgentTurnResponseStepProgressPayload,
|
||||||
AgentTurnResponseStepStartPayload,
|
AgentTurnResponseStepStartPayload,
|
||||||
AgentTurnResponseStreamChunk,
|
AgentTurnResponseStreamChunk,
|
||||||
|
AgentTurnResponseTurnAwaitingInputPayload,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
|
AgentTurnResumeRequest,
|
||||||
Attachment,
|
Attachment,
|
||||||
Document,
|
Document,
|
||||||
InferenceStep,
|
InferenceStep,
|
||||||
|
@ -62,7 +64,11 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
ToolCall,
|
||||||
|
ToolParamDefinition,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||||
from llama_stack.providers.utils.telemetry import tracing
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
|
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||||
|
messages = []
|
||||||
|
if self.agent_config.instructions != "":
|
||||||
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
for turn in turns:
|
||||||
|
messages.extend(self.turn_to_messages(turn))
|
||||||
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages = []
|
|
||||||
if self.agent_config.instructions != "":
|
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
|
||||||
|
|
||||||
for i, turn in enumerate(turns):
|
|
||||||
messages.extend(self.turn_to_messages(turn))
|
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
|
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
if output_message.tool_calls and request.allow_turn_resume:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
|
with tracing.span("resume_turn") as span:
|
||||||
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
|
session_info = await self.storage.get_session_info(request.session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
|
last_turn_messages = [
|
||||||
|
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
|
]
|
||||||
|
|
||||||
|
# get the steps from the turn id
|
||||||
|
steps = []
|
||||||
|
if len(turns) > 0:
|
||||||
|
steps = turns[-1].steps
|
||||||
|
|
||||||
|
# mark tool execution step as complete
|
||||||
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
# we'll create a new tool execution step with current time
|
||||||
|
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||||
|
request.session_id, request.turn_id
|
||||||
|
)
|
||||||
|
now = datetime.now()
|
||||||
|
tool_execution_step = ToolExecutionStep(
|
||||||
|
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||||
|
tool_responses=[
|
||||||
|
ToolResponse(
|
||||||
|
call_id=x.call_id,
|
||||||
|
tool_name=x.tool_name,
|
||||||
|
content=x.content,
|
||||||
|
)
|
||||||
|
for x in request.tool_responses
|
||||||
|
],
|
||||||
|
completed_at=now,
|
||||||
|
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||||
|
)
|
||||||
|
steps.append(tool_execution_step)
|
||||||
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseTurnCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
turn=turn,
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=tool_execution_step.step_id,
|
||||||
|
step_details=tool_execution_step,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
output_message = None
|
||||||
|
async for chunk in self.run(
|
||||||
|
session_id=request.session_id,
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
input_messages=messages,
|
||||||
|
sampling_params=self.agent_config.sampling_params,
|
||||||
|
stream=request.stream,
|
||||||
|
):
|
||||||
|
if isinstance(chunk, CompletionMessage):
|
||||||
|
output_message = chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
|
event = chunk.event
|
||||||
|
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||||
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
assert output_message is not None
|
||||||
|
|
||||||
|
last_turn_start_time = datetime.now()
|
||||||
|
if len(turns) > 0:
|
||||||
|
last_turn_start_time = turns[-1].started_at
|
||||||
|
|
||||||
|
turn = Turn(
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
session_id=request.session_id,
|
||||||
|
input_messages=last_turn_messages,
|
||||||
|
output_message=output_message,
|
||||||
|
started_at=last_turn_start_time,
|
||||||
|
completed_at=datetime.now(),
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
|
if output_message.tool_calls:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
|
@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages = input_messages + [message]
|
input_messages = input_messages + [message]
|
||||||
else:
|
else:
|
||||||
log.info(f"{str(message)}")
|
log.info(f"{str(message)}")
|
||||||
tool_call = message.tool_calls[0]
|
# 1. Start the tool execution step and progress
|
||||||
if tool_call.tool_name in client_tools:
|
|
||||||
yield message
|
|
||||||
return
|
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
tool_call = message.tool_calls[0]
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If tool is a client tool, yield CompletionMessage and return
|
||||||
|
if tool_call.tool_name in client_tools:
|
||||||
|
await self.storage.set_in_progress_tool_call_step(
|
||||||
|
session_id,
|
||||||
|
turn_id,
|
||||||
|
ToolExecutionStep(
|
||||||
|
step_id=step_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_calls=[tool_call],
|
||||||
|
tool_responses=[],
|
||||||
|
started_at=datetime.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
# If tool is a builtin server tool, execute it
|
||||||
tool_name = tool_call.tool_name
|
tool_name = tool_call.tool_name
|
||||||
if isinstance(tool_name, BuiltinTool):
|
if isinstance(tool_name, BuiltinTool):
|
||||||
tool_name = tool_name.value
|
tool_name = tool_name.value
|
||||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentStepResponse,
|
AgentStepResponse,
|
||||||
AgentToolGroup,
|
AgentToolGroup,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
|
AgentTurnResumeRequest,
|
||||||
Document,
|
Document,
|
||||||
Session,
|
Session,
|
||||||
Turn,
|
Turn,
|
||||||
|
@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
allow_turn_resume: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
toolgroups=toolgroups,
|
toolgroups=toolgroups,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
|
allow_turn_resume=allow_turn_resume,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
return self._create_agent_turn_streaming(request)
|
return self._create_agent_turn_streaming(request)
|
||||||
|
@ -177,7 +180,25 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
tool_responses: List[ToolResponseMessage],
|
tool_responses: List[ToolResponseMessage],
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
pass
|
request = AgentTurnResumeRequest(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_responses=tool_responses,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._continue_agent_turn_streaming(request)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||||
|
|
||||||
|
async def _continue_agent_turn_streaming(
|
||||||
|
self,
|
||||||
|
request: AgentTurnResumeRequest,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
agent = await self.get_agent(request.agent_id)
|
||||||
|
async for event in agent.resume_turn(request):
|
||||||
|
yield event
|
||||||
|
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
|
|
@ -12,7 +12,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -84,3 +84,15 @@ class AgentPersistence:
|
||||||
continue
|
continue
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||||
|
await self.kvstore.set(
|
||||||
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
value=step.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||||
|
value = await self.kvstore.get(
|
||||||
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
)
|
||||||
|
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
|
@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||||
from llama_stack_client.types.tool_def_param import Parameter
|
from llama_stack_client.types.tool_def_param import Parameter
|
||||||
|
|
||||||
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
|
from llama_stack.apis.agents.agents import (
|
||||||
from llama_stack.apis.agents.agents import ToolChoice
|
AgentConfig as Server__AgentConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.agents.agents import (
|
||||||
|
ToolChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestClientTool(ClientTool):
|
class TestClientTool(ClientTool):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue