feat: support ClientTool output metadata (#1426)

# Summary:
Client side change in
https://github.com/meta-llama/llama-stack-client-python/pull/180
Changes the resume_turn API to accept `ToolResponse` instead of
`ToolResponseMessage`:
1. `ToolResponse` contains `metadata`
2. `ToolResponseMessage` is a concept for model inputs. Here we are just
submitting the outputs of tool execution.

# Test Plan:
Ran integration tests with newly added test using client tool with
metadata

LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/integration/agents/test_agents.py --safety-shield
meta-llama/Llama-Guard-3-8B --record-responses
This commit is contained in:
ehhuang 2025-03-05 14:30:27 -08:00 committed by GitHub
parent ac717f38dc
commit 6cf79437b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 3984 additions and 2172 deletions

View file

@ -216,13 +216,25 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
messages.extend(request.tool_responses)
if isinstance(request.tool_responses[0], ToolResponseMessage):
tool_response_messages = request.tool_responses
tool_responses = [
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
else:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
tool_responses = request.tool_responses
messages.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(request.tool_responses)
last_turn_messages.extend(tool_response_messages)
# get steps from the turn
steps = last_turn.steps
@ -238,14 +250,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
tool_responses=tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)