forked from phoenix-oss/llama-stack-mirror
feat(1/n): api: unify agents for handling server & client tools (#1178)
# Problem
Our current Agent framework has discrepancies in definition on how we
handle server side and client side tools.
1. Server Tools: a single Turn is returned including `ToolExecutionStep`
in agenst
2. Client Tools: `create_agent_turn` is called in loop with client agent
lib yielding the agent chunk
ad6ffc63df/src/llama_stack_client/lib/agents/agent.py (L186-L211)
This makes it inconsistent to work with server & client tools. It also
complicates the logs to telemetry to get information about agents turn /
history for observability.
#### Principle
The same `turn_id` should be used to represent the steps required to
complete a user message including client tools.
## Solution
1. `AgentTurnResponseEventType.turn_awaiting_input` status to indicate
that the current turn is not completed, and awaiting tool input
2. `continue_agent_turn` endpoint to update agent turn with client's
tool response.
# What does this PR do?
- Skeleton API as example
## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]
- Just API update, no functionality change
```
llama stack run + client-sdk test
```
<img width="842" alt="image"
src="https://github.com/user-attachments/assets/7ac56b5f-f424-4632-9476-7e0f57555bc3"
/>
[//]: # (## Documentation)
This commit is contained in:
parent
992f865b2e
commit
0fe071764f
7 changed files with 454 additions and 21 deletions
113
docs/_static/llama-stack-spec.html
vendored
113
docs/_static/llama-stack-spec.html
vendored
|
@ -2315,6 +2315,70 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/Turn"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"text/event-stream": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Agents"
|
||||||
|
],
|
||||||
|
"description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "agent_id",
|
||||||
|
"in": "path",
|
||||||
|
"description": "The ID of the agent to resume.",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "session_id",
|
||||||
|
"in": "path",
|
||||||
|
"description": "The ID of the session to resume.",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "turn_id",
|
||||||
|
"in": "path",
|
||||||
|
"description": "The ID of the turn to resume.",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ResumeAgentTurnRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/v1/eval/benchmarks/{benchmark_id}/jobs": {
|
"/v1/eval/benchmarks/{benchmark_id}/jobs": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -4226,6 +4290,9 @@
|
||||||
},
|
},
|
||||||
"tool_config": {
|
"tool_config": {
|
||||||
"$ref": "#/components/schemas/ToolConfig"
|
"$ref": "#/components/schemas/ToolConfig"
|
||||||
|
},
|
||||||
|
"allow_turn_resume": {
|
||||||
|
"type": "boolean"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -4612,6 +4679,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
|
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"discriminator": {
|
"discriminator": {
|
||||||
|
@ -4621,7 +4691,8 @@
|
||||||
"step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
|
"step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload",
|
||||||
"step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
|
"step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload",
|
||||||
"turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
|
"turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload",
|
||||||
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
|
"turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload",
|
||||||
|
"turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -4784,6 +4855,25 @@
|
||||||
"title": "AgentTurnResponseStreamChunk",
|
"title": "AgentTurnResponseStreamChunk",
|
||||||
"description": "streamed agent turn completion response."
|
"description": "streamed agent turn completion response."
|
||||||
},
|
},
|
||||||
|
"AgentTurnResponseTurnAwaitingInputPayload": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"event_type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "turn_awaiting_input",
|
||||||
|
"default": "turn_awaiting_input"
|
||||||
|
},
|
||||||
|
"turn": {
|
||||||
|
"$ref": "#/components/schemas/Turn"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"event_type",
|
||||||
|
"turn"
|
||||||
|
],
|
||||||
|
"title": "AgentTurnResponseTurnAwaitingInputPayload"
|
||||||
|
},
|
||||||
"AgentTurnResponseTurnCompletePayload": {
|
"AgentTurnResponseTurnCompletePayload": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -8046,6 +8136,27 @@
|
||||||
],
|
],
|
||||||
"title": "RegisterVectorDbRequest"
|
"title": "RegisterVectorDbRequest"
|
||||||
},
|
},
|
||||||
|
"ResumeAgentTurnRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tool_responses": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/ToolResponseMessage"
|
||||||
|
},
|
||||||
|
"description": "The tool call responses to resume the turn with."
|
||||||
|
},
|
||||||
|
"stream": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to stream the response."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"tool_responses"
|
||||||
|
],
|
||||||
|
"title": "ResumeAgentTurnRequest"
|
||||||
|
},
|
||||||
"RunEvalRequest": {
|
"RunEvalRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
82
docs/_static/llama-stack-spec.yaml
vendored
82
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1401,6 +1401,53 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/QueryTracesRequest'
|
$ref: '#/components/schemas/QueryTracesRequest'
|
||||||
required: true
|
required: true
|
||||||
|
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: >-
|
||||||
|
A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk
|
||||||
|
objects.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Turn'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/AgentTurnResponseStreamChunk'
|
||||||
|
tags:
|
||||||
|
- Agents
|
||||||
|
description: >-
|
||||||
|
Resume an agent turn with executed tool call responses.
|
||||||
|
|
||||||
|
When a Turn has the status `awaiting_input` due to pending input from client
|
||||||
|
side tool calls, this endpoint can be used to submit the outputs from the
|
||||||
|
tool calls once they are ready.
|
||||||
|
parameters:
|
||||||
|
- name: agent_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the agent to resume.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: session_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the session to resume.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: turn_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the turn to resume.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ResumeAgentTurnRequest'
|
||||||
|
required: true
|
||||||
/v1/eval/benchmarks/{benchmark_id}/jobs:
|
/v1/eval/benchmarks/{benchmark_id}/jobs:
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
|
@ -2740,6 +2787,8 @@ components:
|
||||||
$ref: '#/components/schemas/AgentTool'
|
$ref: '#/components/schemas/AgentTool'
|
||||||
tool_config:
|
tool_config:
|
||||||
$ref: '#/components/schemas/ToolConfig'
|
$ref: '#/components/schemas/ToolConfig'
|
||||||
|
allow_turn_resume:
|
||||||
|
type: boolean
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- messages
|
- messages
|
||||||
|
@ -2992,6 +3041,7 @@ components:
|
||||||
- $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
|
- $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload'
|
||||||
- $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
|
- $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload'
|
||||||
- $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
|
- $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
|
||||||
|
- $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
|
||||||
discriminator:
|
discriminator:
|
||||||
propertyName: event_type
|
propertyName: event_type
|
||||||
mapping:
|
mapping:
|
||||||
|
@ -3000,6 +3050,7 @@ components:
|
||||||
step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
|
step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload'
|
||||||
turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
|
turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload'
|
||||||
turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
|
turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload'
|
||||||
|
turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload'
|
||||||
AgentTurnResponseStepCompletePayload:
|
AgentTurnResponseStepCompletePayload:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -3106,6 +3157,21 @@ components:
|
||||||
- event
|
- event
|
||||||
title: AgentTurnResponseStreamChunk
|
title: AgentTurnResponseStreamChunk
|
||||||
description: streamed agent turn completion response.
|
description: streamed agent turn completion response.
|
||||||
|
"AgentTurnResponseTurnAwaitingInputPayload":
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
event_type:
|
||||||
|
type: string
|
||||||
|
const: turn_awaiting_input
|
||||||
|
default: turn_awaiting_input
|
||||||
|
turn:
|
||||||
|
$ref: '#/components/schemas/Turn'
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- event_type
|
||||||
|
- turn
|
||||||
|
title: >-
|
||||||
|
AgentTurnResponseTurnAwaitingInputPayload
|
||||||
AgentTurnResponseTurnCompletePayload:
|
AgentTurnResponseTurnCompletePayload:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -5205,6 +5271,22 @@ components:
|
||||||
- vector_db_id
|
- vector_db_id
|
||||||
- embedding_model
|
- embedding_model
|
||||||
title: RegisterVectorDbRequest
|
title: RegisterVectorDbRequest
|
||||||
|
ResumeAgentTurnRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
tool_responses:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/ToolResponseMessage'
|
||||||
|
description: >-
|
||||||
|
The tool call responses to resume the turn with.
|
||||||
|
stream:
|
||||||
|
type: boolean
|
||||||
|
description: Whether to stream the response.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- tool_responses
|
||||||
|
title: ResumeAgentTurnRequest
|
||||||
RunEvalRequest:
|
RunEvalRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum):
|
||||||
|
|
||||||
turn_start = "turn_start"
|
turn_start = "turn_start"
|
||||||
turn_complete = "turn_complete"
|
turn_complete = "turn_complete"
|
||||||
|
turn_awaiting_input = "turn_awaiting_input"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||||
turn: Turn
|
turn: Turn
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||||
|
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
|
||||||
|
AgentTurnResponseEventType.turn_awaiting_input.value
|
||||||
|
)
|
||||||
|
turn: Turn
|
||||||
|
|
||||||
|
|
||||||
AgentTurnResponseEventPayload = register_schema(
|
AgentTurnResponseEventPayload = register_schema(
|
||||||
Annotated[
|
Annotated[
|
||||||
Union[
|
Union[
|
||||||
|
@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema(
|
||||||
AgentTurnResponseStepCompletePayload,
|
AgentTurnResponseStepCompletePayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
|
AgentTurnResponseTurnAwaitingInputPayload,
|
||||||
],
|
],
|
||||||
Field(discriminator="event_type"),
|
Field(discriminator="event_type"),
|
||||||
],
|
],
|
||||||
|
@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
tool_config: Optional[ToolConfig] = None
|
tool_config: Optional[ToolConfig] = None
|
||||||
|
|
||||||
|
# TODO (xiyan): temporary flag, will remove for 0.1.5
|
||||||
|
allow_turn_resume: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AgentTurnResumeRequest(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
session_id: str
|
||||||
|
turn_id: str
|
||||||
|
tool_responses: List[ToolResponseMessage]
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentTurnResponseStreamChunk(BaseModel):
|
class AgentTurnResponseStreamChunk(BaseModel):
|
||||||
|
@ -333,8 +355,34 @@ class Agents(Protocol):
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
allow_turn_resume: Optional[bool] = False,
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
|
@webmethod(
|
||||||
|
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||||
|
method="POST",
|
||||||
|
)
|
||||||
|
async def resume_agent_turn(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
turn_id: str,
|
||||||
|
tool_responses: List[ToolResponseMessage],
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||||
|
"""Resume an agent turn with executed tool call responses.
|
||||||
|
|
||||||
|
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||||
|
|
||||||
|
:param agent_id: The ID of the agent to resume.
|
||||||
|
:param session_id: The ID of the session to resume.
|
||||||
|
:param turn_id: The ID of the turn to resume.
|
||||||
|
:param tool_responses: The tool call responses to resume the turn with.
|
||||||
|
:param stream: Whether to stream the response.
|
||||||
|
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(
|
||||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||||
method="GET",
|
method="GET",
|
||||||
|
|
|
@ -30,8 +30,10 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnResponseStepProgressPayload,
|
AgentTurnResponseStepProgressPayload,
|
||||||
AgentTurnResponseStepStartPayload,
|
AgentTurnResponseStepStartPayload,
|
||||||
AgentTurnResponseStreamChunk,
|
AgentTurnResponseStreamChunk,
|
||||||
|
AgentTurnResponseTurnAwaitingInputPayload,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
|
AgentTurnResumeRequest,
|
||||||
Attachment,
|
Attachment,
|
||||||
Document,
|
Document,
|
||||||
InferenceStep,
|
InferenceStep,
|
||||||
|
@ -62,7 +64,11 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
ToolCall,
|
||||||
|
ToolParamDefinition,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||||
from llama_stack.providers.utils.telemetry import tracing
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
|
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||||
|
messages = []
|
||||||
|
if self.agent_config.instructions != "":
|
||||||
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
for turn in turns:
|
||||||
|
messages.extend(self.turn_to_messages(turn))
|
||||||
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages = []
|
|
||||||
if self.agent_config.instructions != "":
|
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
|
||||||
|
|
||||||
for i, turn in enumerate(turns):
|
|
||||||
messages.extend(self.turn_to_messages(turn))
|
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
|
@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
if output_message.tool_calls and request.allow_turn_resume:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
|
with tracing.span("resume_turn") as span:
|
||||||
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
|
session_info = await self.storage.get_session_info(request.session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
|
last_turn_messages = [
|
||||||
|
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
|
]
|
||||||
|
|
||||||
|
# get the steps from the turn id
|
||||||
|
steps = []
|
||||||
|
if len(turns) > 0:
|
||||||
|
steps = turns[-1].steps
|
||||||
|
|
||||||
|
# mark tool execution step as complete
|
||||||
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
# we'll create a new tool execution step with current time
|
||||||
|
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||||
|
request.session_id, request.turn_id
|
||||||
|
)
|
||||||
|
now = datetime.now()
|
||||||
|
tool_execution_step = ToolExecutionStep(
|
||||||
|
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||||
|
tool_responses=[
|
||||||
|
ToolResponse(
|
||||||
|
call_id=x.call_id,
|
||||||
|
tool_name=x.tool_name,
|
||||||
|
content=x.content,
|
||||||
|
)
|
||||||
|
for x in request.tool_responses
|
||||||
|
],
|
||||||
|
completed_at=now,
|
||||||
|
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||||
|
)
|
||||||
|
steps.append(tool_execution_step)
|
||||||
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseTurnCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
turn=turn,
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=tool_execution_step.step_id,
|
||||||
|
step_details=tool_execution_step,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
output_message = None
|
||||||
|
async for chunk in self.run(
|
||||||
|
session_id=request.session_id,
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
input_messages=messages,
|
||||||
|
sampling_params=self.agent_config.sampling_params,
|
||||||
|
stream=request.stream,
|
||||||
|
):
|
||||||
|
if isinstance(chunk, CompletionMessage):
|
||||||
|
output_message = chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
|
event = chunk.event
|
||||||
|
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||||
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
assert output_message is not None
|
||||||
|
|
||||||
|
last_turn_start_time = datetime.now()
|
||||||
|
if len(turns) > 0:
|
||||||
|
last_turn_start_time = turns[-1].started_at
|
||||||
|
|
||||||
|
turn = Turn(
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
session_id=request.session_id,
|
||||||
|
input_messages=last_turn_messages,
|
||||||
|
output_message=output_message,
|
||||||
|
started_at=last_turn_start_time,
|
||||||
|
completed_at=datetime.now(),
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
|
if output_message.tool_calls:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
|
@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages = input_messages + [message]
|
input_messages = input_messages + [message]
|
||||||
else:
|
else:
|
||||||
log.info(f"{str(message)}")
|
log.info(f"{str(message)}")
|
||||||
tool_call = message.tool_calls[0]
|
# 1. Start the tool execution step and progress
|
||||||
if tool_call.tool_name in client_tools:
|
|
||||||
yield message
|
|
||||||
return
|
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
tool_call = message.tool_calls[0]
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If tool is a client tool, yield CompletionMessage and return
|
||||||
|
if tool_call.tool_name in client_tools:
|
||||||
|
await self.storage.set_in_progress_tool_call_step(
|
||||||
|
session_id,
|
||||||
|
turn_id,
|
||||||
|
ToolExecutionStep(
|
||||||
|
step_id=step_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_calls=[tool_call],
|
||||||
|
tool_responses=[],
|
||||||
|
started_at=datetime.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
# If tool is a builtin server tool, execute it
|
||||||
tool_name = tool_call.tool_name
|
tool_name = tool_call.tool_name
|
||||||
if isinstance(tool_name, BuiltinTool):
|
if isinstance(tool_name, BuiltinTool):
|
||||||
tool_name = tool_name.value
|
tool_name = tool_name.value
|
||||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentStepResponse,
|
AgentStepResponse,
|
||||||
AgentToolGroup,
|
AgentToolGroup,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
|
AgentTurnResumeRequest,
|
||||||
Document,
|
Document,
|
||||||
Session,
|
Session,
|
||||||
Turn,
|
Turn,
|
||||||
|
@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
documents: Optional[List[Document]] = None,
|
documents: Optional[List[Document]] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
allow_turn_resume: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
toolgroups=toolgroups,
|
toolgroups=toolgroups,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
|
allow_turn_resume=allow_turn_resume,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
return self._create_agent_turn_streaming(request)
|
return self._create_agent_turn_streaming(request)
|
||||||
|
@ -169,6 +172,34 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
async for event in agent.create_and_execute_turn(request):
|
async for event in agent.create_and_execute_turn(request):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
async def resume_agent_turn(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
turn_id: str,
|
||||||
|
tool_responses: List[ToolResponseMessage],
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
request = AgentTurnResumeRequest(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_responses=tool_responses,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._continue_agent_turn_streaming(request)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||||
|
|
||||||
|
async def _continue_agent_turn_streaming(
|
||||||
|
self,
|
||||||
|
request: AgentTurnResumeRequest,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
agent = await self.get_agent(request.agent_id)
|
||||||
|
async for event in agent.resume_turn(request):
|
||||||
|
yield event
|
||||||
|
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
turn = json.loads(turn)
|
turn = json.loads(turn)
|
||||||
|
|
|
@ -12,7 +12,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -84,3 +84,15 @@ class AgentPersistence:
|
||||||
continue
|
continue
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||||
|
await self.kvstore.set(
|
||||||
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
value=step.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||||
|
value = await self.kvstore.get(
|
||||||
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
)
|
||||||
|
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
|
@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||||
from llama_stack_client.types.tool_def_param import Parameter
|
from llama_stack_client.types.tool_def_param import Parameter
|
||||||
|
|
||||||
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
|
from llama_stack.apis.agents.agents import (
|
||||||
from llama_stack.apis.agents.agents import ToolChoice
|
AgentConfig as Server__AgentConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.agents.agents import (
|
||||||
|
ToolChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestClientTool(ClientTool):
|
class TestClientTool(ClientTool):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue