mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
feat(3/n): agent resume_turn (#1194)
# What does this PR do? - https://github.com/meta-llama/llama-stack/pull/1178 - https://github.com/meta-llama/llama-stack/pull/1187 - https://github.com/meta-llama/llama-stack/pull/1194 **client changes** - https://github.com/meta-llama/llama-stack-client-python/pull/157 - https://github.com/meta-llama/llama-stack-client-python/pull/158 ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` ``` LLAMA_STACK_CONFIG=fireworks pytest -v tests/client-sdk/agents/test_agents.py --inference-model meta-llama/Llama-3.1-8B-Instruct ``` <img width="1080" alt="image" src="https://github.com/user-attachments/assets/6c48e062-5d00-4fed-98de-ecca627e5195" /> **llama-stack-apps** ``` python -m examples.agents.react_agent localhost 8321 ``` - Test with script: https://gist.github.com/yanxi0830/f2e407527f468998a700cd29fd271b15 **Output Before**: we have 2 `turn_id` with 2 turns - https://gist.github.com/yanxi0830/9fbd7a80fcddc784a28c59d4a9c1d943 **Output After**: we have 1 `turn_id`, the final turn have all 3 steps - https://gist.github.com/yanxi0830/17754d56d08ccbeaec419b693137500c <img width="622" alt="image" src="https://github.com/user-attachments/assets/ed7f42ca-41c8-4ee9-b6fa-2299901cafb4" /> **Telemetry** <img width="1389" alt="image" src="https://github.com/user-attachments/assets/031bf6ab-959c-46a7-aa85-70a511cba296" /> [//]: # (## Documentation)
This commit is contained in:
parent
4f2427c6c8
commit
4dc7f05a2d
3 changed files with 168 additions and 11 deletions
|
@ -33,6 +33,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnResponseTurnAwaitingInputPayload,
|
AgentTurnResponseTurnAwaitingInputPayload,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
|
AgentTurnResumeRequest,
|
||||||
Attachment,
|
Attachment,
|
||||||
Document,
|
Document,
|
||||||
InferenceStep,
|
InferenceStep,
|
||||||
|
@ -156,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
|
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||||
|
messages = []
|
||||||
|
if self.agent_config.instructions != "":
|
||||||
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
for turn in turns:
|
||||||
|
messages.extend(self.turn_to_messages(turn))
|
||||||
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -168,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages = []
|
|
||||||
if self.agent_config.instructions != "":
|
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
|
||||||
|
|
||||||
for i, turn in enumerate(turns):
|
|
||||||
messages.extend(self.turn_to_messages(turn))
|
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
|
@ -246,6 +249,119 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||||
|
with tracing.span("resume_turn") as span:
|
||||||
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
span.set_attribute("turn_id", request.turn_id)
|
||||||
|
span.set_attribute("request", request.model_dump_json())
|
||||||
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
|
session_info = await self.storage.get_session_info(request.session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
|
last_turn_messages = [
|
||||||
|
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
|
]
|
||||||
|
|
||||||
|
# get the steps from the turn id
|
||||||
|
steps = []
|
||||||
|
if len(turns) > 0:
|
||||||
|
steps = turns[-1].steps
|
||||||
|
|
||||||
|
# mark tool execution step as complete
|
||||||
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
# we'll create a new tool execution step with current time
|
||||||
|
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||||
|
request.session_id, request.turn_id
|
||||||
|
)
|
||||||
|
now = datetime.now()
|
||||||
|
tool_execution_step = ToolExecutionStep(
|
||||||
|
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||||
|
tool_responses=[
|
||||||
|
ToolResponse(
|
||||||
|
call_id=x.call_id,
|
||||||
|
tool_name=x.tool_name,
|
||||||
|
content=x.content,
|
||||||
|
)
|
||||||
|
for x in request.tool_responses
|
||||||
|
],
|
||||||
|
completed_at=now,
|
||||||
|
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||||
|
)
|
||||||
|
steps.append(tool_execution_step)
|
||||||
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=tool_execution_step.step_id,
|
||||||
|
step_details=tool_execution_step,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_message = None
|
||||||
|
async for chunk in self.run(
|
||||||
|
session_id=request.session_id,
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
input_messages=messages,
|
||||||
|
sampling_params=self.agent_config.sampling_params,
|
||||||
|
stream=request.stream,
|
||||||
|
):
|
||||||
|
if isinstance(chunk, CompletionMessage):
|
||||||
|
output_message = chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
|
event = chunk.event
|
||||||
|
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||||
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
assert output_message is not None
|
||||||
|
|
||||||
|
last_turn_start_time = datetime.now()
|
||||||
|
if len(turns) > 0:
|
||||||
|
last_turn_start_time = turns[-1].started_at
|
||||||
|
|
||||||
|
turn = Turn(
|
||||||
|
turn_id=request.turn_id,
|
||||||
|
session_id=request.session_id,
|
||||||
|
input_messages=last_turn_messages,
|
||||||
|
output_message=output_message,
|
||||||
|
started_at=last_turn_start_time,
|
||||||
|
completed_at=datetime.now(),
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
|
if output_message.tool_calls:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
@ -636,7 +752,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -654,6 +769,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
# If tool is a client tool, yield CompletionMessage and return
|
# If tool is a client tool, yield CompletionMessage and return
|
||||||
if tool_call.tool_name in client_tools:
|
if tool_call.tool_name in client_tools:
|
||||||
|
await self.storage.set_in_progress_tool_call_step(
|
||||||
|
session_id,
|
||||||
|
turn_id,
|
||||||
|
ToolExecutionStep(
|
||||||
|
step_id=step_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_calls=[tool_call],
|
||||||
|
tool_responses=[],
|
||||||
|
started_at=datetime.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentStepResponse,
|
AgentStepResponse,
|
||||||
AgentToolGroup,
|
AgentToolGroup,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
|
AgentTurnResumeRequest,
|
||||||
Document,
|
Document,
|
||||||
Session,
|
Session,
|
||||||
Turn,
|
Turn,
|
||||||
|
@ -179,7 +180,25 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
tool_responses: List[ToolResponseMessage],
|
tool_responses: List[ToolResponseMessage],
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
pass
|
request = AgentTurnResumeRequest(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_responses=tool_responses,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._continue_agent_turn_streaming(request)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||||
|
|
||||||
|
async def _continue_agent_turn_streaming(
|
||||||
|
self,
|
||||||
|
request: AgentTurnResumeRequest,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
agent = await self.get_agent(request.agent_id)
|
||||||
|
async for event in agent.resume_turn(request):
|
||||||
|
yield event
|
||||||
|
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
|
|
@ -12,7 +12,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -84,3 +84,15 @@ class AgentPersistence:
|
||||||
continue
|
continue
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||||
|
await self.kvstore.set(
|
||||||
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
value=step.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||||
|
value = await self.kvstore.get(
|
||||||
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
)
|
||||||
|
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue