feat(3/n): agent resume_turn (#1194)

# What does this PR do?
- https://github.com/meta-llama/llama-stack/pull/1178
- https://github.com/meta-llama/llama-stack/pull/1187
- https://github.com/meta-llama/llama-stack/pull/1194

**client changes**
- https://github.com/meta-llama/llama-stack-client-python/pull/157
- https://github.com/meta-llama/llama-stack-client-python/pull/158

## Test Plan
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct
```

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct
```

<img width="1080" alt="image"
src="https://github.com/user-attachments/assets/6c48e062-5d00-4fed-98de-ecca627e5195"
/>

**llama-stack-apps**
```
python -m examples.agents.react_agent localhost 8321
```

- Test with script:
https://gist.github.com/yanxi0830/f2e407527f468998a700cd29fd271b15

**Output Before**: we have 2 `turn_id` with 2 turns
- https://gist.github.com/yanxi0830/9fbd7a80fcddc784a28c59d4a9c1d943

**Output After**: we have 1 `turn_id`, the final turn have all 3 steps
- https://gist.github.com/yanxi0830/17754d56d08ccbeaec419b693137500c

<img width="622" alt="image"
src="https://github.com/user-attachments/assets/ed7f42ca-41c8-4ee9-b6fa-2299901cafb4"
/>

**Telemetry**
<img width="1389" alt="image"
src="https://github.com/user-attachments/assets/031bf6ab-959c-46a7-aa85-70a511cba296"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-02-21 11:43:48 -08:00 committed by GitHub
parent 4f2427c6c8
commit 4dc7f05a2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 168 additions and 11 deletions

View file

@ -33,6 +33,7 @@ from llama_stack.apis.agents import (
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
InferenceStep,
@ -156,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -168,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
@ -246,6 +249,119 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
# get the steps from the turn id
steps = []
if len(turns) > 0:
steps = turns[-1].steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
self,
session_id: str,
@ -636,7 +752,6 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -654,6 +769,17 @@ class ChatAgent(ShieldRunnerMixin):
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
yield message
return

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
Session,
Turn,
@ -179,7 +180,25 @@ class MetaReferenceAgentsImpl(Agents):
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
pass
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")

View file

@ -12,7 +12,7 @@ from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
@ -84,3 +84,15 @@ class AgentPersistence:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None