Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
feat(agent): support multiple tool groups (#1556)
Summary: closes #1488

Test Plan: added new integration test

```
LLAMA_STACK_CONFIG=dev pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model openai/gpt-4o-mini
```

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1556).

* __->__ #1556
* #1550
This commit is contained in:
parent c23a7af5d6
commit 37f155e41d

3 changed files with 157 additions and 108 deletions
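At a glance: before this change the agent loop assumed at most one tool call per inference step (`tool_call = message.tool_calls[0]`) and returned to the caller as soon as that call belonged to a client tool. The diffs below instead split each step's tool calls into server-executed and client-executed groups, run every server call (each as its own tool-execution step), and only then pause the turn for the client. A minimal, self-contained sketch of that control flow (illustrative names and stand-in types; the real implementation is the `ChatAgent` diff below, which also streams events and persists steps):

```python
# Sketch of the new per-iteration tool dispatch (not the verbatim implementation).
from dataclasses import dataclass


@dataclass
class ToolCall:
    call_id: str
    tool_name: str


def dispatch(tool_calls: list[ToolCall], client_tools: set[str]):
    client_calls = [tc for tc in tool_calls if tc.tool_name in client_tools]
    server_calls = [tc for tc in tool_calls if tc.tool_name not in client_tools]

    # Every server-side call executes now, one tool-execution step each.
    results = [f"ran {tc.tool_name}" for tc in server_calls]

    # Client-side calls are deferred: the turn pauses once, carrying all of them.
    return results, client_calls


if __name__ == "__main__":
    calls = [ToolCall("1", "knowledge_search"), ToolCall("2", "get_boiling_point")]
    print(dispatch(calls, client_tools={"get_boiling_point"}))
```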
Changes to `class ChatAgent(ShieldRunnerMixin)`:

```diff
@@ -614,8 +614,21 @@ class ChatAgent(ShieldRunnerMixin):
                 logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
                 input_messages = input_messages + [message]
             else:
-                logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
-                # 1. Start the tool execution step and progress
+                input_messages = input_messages + [message]
+
+                # Process tool calls in the message
+                client_tool_calls = []
+                non_client_tool_calls = []
+
+                # Separate client and non-client tool calls
+                for tool_call in message.tool_calls:
+                    if tool_call.tool_name in client_tools:
+                        client_tool_calls.append(tool_call)
+                    else:
+                        non_client_tool_calls.append(tool_call)
+
+                # Process non-client tool calls first
+                for tool_call in non_client_tool_calls:
                     step_id = str(uuid.uuid4())
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
```
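The pivotal move in this hunk: everything from `step_id = str(uuid.uuid4())` onward becomes the body of `for tool_call in non_client_tool_calls:`, so each server-side call gets its own step id and its own start/progress/complete events. A runnable miniature (stand-in call names; the event payloads are elided):

```python
import uuid

non_client_tool_calls = ["web_search", "get_boiling_point"]  # stand-ins
for tool_call in non_client_tool_calls:
    step_id = str(uuid.uuid4())  # fresh step per call, as in the diff
    print(f"step {step_id[:8]} starts tool {tool_call}")
    # ...yield step-start / progress / completion events scoped to this step_id
```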
```diff
@@ -625,13 +638,12 @@ class ChatAgent(ShieldRunnerMixin):
                             )
                         )
                     )
-                    tool_call = message.tool_calls[0]
+
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.tool_execution.value,
                                 step_id=step_id,
-                                tool_call=tool_call,
                                 delta=ToolCallDelta(
                                     parse_status=ToolCallParseStatus.in_progress,
                                     tool_call=tool_call,
```
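Two single-call assumptions disappear here: the `tool_call = message.tool_calls[0]` prefix and the redundant `tool_call=` kwarg on the progress payload (the in-flight call already travels inside `ToolCallDelta`). A self-contained look at the slimmed-down delta, using stand-in dataclasses since the real Pydantic models live in llama-stack's API packages:

```python
from dataclasses import dataclass
from enum import Enum


class ToolCallParseStatus(Enum):  # stand-in for the real enum
    in_progress = "in_progress"


@dataclass
class ToolCallDelta:  # stand-in for the real model
    parse_status: ToolCallParseStatus
    tool_call: str


# The delta is now the only place the executing call appears in the payload.
delta = ToolCallDelta(ToolCallParseStatus.in_progress, "get_boiling_point")
print(delta)
```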
```diff
@@ -640,38 +652,15 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )

-                    # If tool is a client tool, yield CompletionMessage and return
-                    if tool_call.tool_name in client_tools:
-                        # NOTE: mark end_of_message to indicate to client that it may
-                        # call the tool and continue the conversation with the tool's response.
-                        message.stop_reason = StopReason.end_of_message
-                        await self.storage.set_in_progress_tool_call_step(
-                            session_id,
-                            turn_id,
-                            ToolExecutionStep(
-                                step_id=step_id,
-                                turn_id=turn_id,
-                                tool_calls=[tool_call],
-                                tool_responses=[],
-                                started_at=datetime.now(timezone.utc).isoformat(),
-                            ),
-                        )
-                        yield message
-                        return
-
-                    # If tool is a builtin server tool, execute it
-                    tool_name = tool_call.tool_name
-                    if isinstance(tool_name, BuiltinTool):
-                        tool_name = tool_name.value
+                    # Execute the tool call
                     async with tracing.span(
                         "tool_execution",
                         {
-                            "tool_name": tool_name,
+                            "tool_name": tool_call.tool_name,
                             "input": message.model_dump_json(),
                         },
                     ) as span:
                         tool_execution_start_time = datetime.now(timezone.utc).isoformat()
-                        tool_call = message.tool_calls[0]
                         tool_result = await self.execute_tool_call_maybe(
                             session_id,
                             tool_call,
```
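With the client-tool early return gone (it reappears, generalized, at the end of the loop), every server call now executes inside its own tracing span, tagged with `tool_call.tool_name` directly rather than a pre-unwrapped `BuiltinTool` value. A runnable approximation of the per-call execution (`tracing.span` and `execute_tool_call_maybe` are the real names from the diff, but the bodies here are stand-ins so the sketch runs on its own):

```python
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def span(name: str, attrs: dict):  # stand-in for tracing.span
    print("enter", name, attrs)
    yield
    print("exit", name)


async def execute_tool_call_maybe(session_id: str, tool_call: str) -> str:
    return f"result of {tool_call}"  # stand-in executor


async def run_server_calls(session_id: str, non_client_tool_calls: list[str]):
    # One span per server-side call, mirroring the loop in the diff.
    for tool_call in non_client_tool_calls:
        async with span("tool_execution", {"tool_name": tool_call}):
            print(await execute_tool_call_maybe(session_id, tool_call))


asyncio.run(run_server_calls("sess-1", ["get_boiling_point", "web_search"]))
```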
```diff
@@ -680,40 +669,43 @@ class ChatAgent(ShieldRunnerMixin):
                             raise ValueError(
                                 f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
                             )
-                        result_messages = [
-                            ToolResponseMessage(
+                        result_message = ToolResponseMessage(
                             call_id=tool_call.call_id,
                             content=tool_result.content,
                         )
-                        ]
-                        assert len(result_messages) == 1, "Currently not supporting multiple messages"
-                        result_message = result_messages[0]
                         span.set_attribute("output", result_message.model_dump_json())

-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                                step_details=ToolExecutionStep(
+                    # Store tool execution step
+                    tool_execution_step = ToolExecutionStep(
                         step_id=step_id,
                         turn_id=turn_id,
                         tool_calls=[tool_call],
                         tool_responses=[
                             ToolResponse(
-                                call_id=result_message.call_id,
+                                call_id=tool_call.call_id,
                                 tool_name=tool_call.tool_name,
-                                content=result_message.content,
+                                content=tool_result.content,
                                 metadata=tool_result.metadata,
                             )
                         ],
                         started_at=tool_execution_start_time,
                         completed_at=datetime.now(timezone.utc).isoformat(),
-                        ),
+                    )
+
+                    # Yield the step completion event
+                    yield AgentTurnResponseStreamChunk(
+                        event=AgentTurnResponseEvent(
+                            payload=AgentTurnResponseStepCompletePayload(
+                                step_type=StepType.tool_execution.value,
+                                step_id=step_id,
+                                step_details=tool_execution_step,
                             )
                         )
                     )

+                    # Add the result message to input_messages for the next iteration
+                    input_messages.append(result_message)
+
                     # TODO: add tool-input touchpoint and a "start" event for this step also
                     # but that needs a lot more refactoring of Tool code potentially
                     if (type(result_message.content) is str) and (
```
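The completion payload is restructured too: the `ToolExecutionStep` is built once, reused in the step-complete event, and the tool's response message is appended to `input_messages` so the next inference iteration sees it. Since the assistant `message` itself was already appended in the first hunk, the per-iteration accumulation boils down to this miniature (strings stand in for the real message objects):

```python
# Miniature of message accumulation across one agent iteration.
input_messages = ["user: boiling point of polyjuice in C and F?"]
assistant_message = "assistant: <2 calls to get_boiling_point>"
input_messages = input_messages + [assistant_message]  # first hunk
for result_message in ["tool: -100", "tool: -212"]:     # one per server call
    input_messages.append(result_message)               # this hunk
print(input_messages)
```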
```diff
@@ -724,7 +716,30 @@ class ChatAgent(ShieldRunnerMixin):
                         # with the summary. We keep all generated attachments and then attach them to final message
                         output_attachments.append(out_attachment)

-                input_messages = input_messages + [message, result_message]
+                # If there are client tool calls, yield a message with only those tool calls
+                if client_tool_calls:
+                    await self.storage.set_in_progress_tool_call_step(
+                        session_id,
+                        turn_id,
+                        ToolExecutionStep(
+                            step_id=step_id,
+                            turn_id=turn_id,
+                            tool_calls=client_tool_calls,
+                            tool_responses=[],
+                            started_at=datetime.now(timezone.utc).isoformat(),
+                        ),
+                    )
+
+                    # Create a copy of the message with only client tool calls
+                    client_message = message.model_copy(deep=True)
+                    client_message.tool_calls = client_tool_calls
+                    # NOTE: mark end_of_message to indicate to client that it may
+                    # call the tool and continue the conversation with the tool's response.
+                    client_message.stop_reason = StopReason.end_of_message
+
+                    # Yield the message with client tool calls
+                    yield client_message
+                    return

     async def _initialize_tools(
         self,
```
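This is where the deferred client calls finally surface: the pending step (now carrying all client calls at once) is persisted, then the assistant message is deep-copied, trimmed to just the client calls, marked `end_of_message`, yielded, and the turn generator returns until the client resumes with tool responses. A shape-level sketch (`model_copy(deep=True)` is the real Pydantic v2 call from the diff; the message model here is a stand-in):

```python
from pydantic import BaseModel


class CompletionMessage(BaseModel):  # stand-in shape for the real model
    content: str
    tool_calls: list[str]
    stop_reason: str = "end_of_turn"


message = CompletionMessage(content="", tool_calls=["server_tool", "client_tool"])
client_message = message.model_copy(deep=True)
client_message.tool_calls = ["client_tool"]    # only the client's share
client_message.stop_reason = "end_of_message"  # turn is paused, not finished
print(client_message)
```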
Changes to `class LlamaGuardShield`:

```diff
@@ -227,13 +227,6 @@ class LlamaGuardShield:
         if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
             messages = messages[1:]

-        for i in range(1, len(messages)):
-            if messages[i].role == messages[i - 1].role:
-                for i, m in enumerate(messages):
-                    print(f"{i}: {m.role}: {m.content}")
-                raise ValueError(
-                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
-                )
         return messages

     async def run(self, messages: List[Message]) -> RunShieldResponse:
```
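The deleted alternation check (along with its leftover debug `print` loop) is a casualty of multi-tool turns: executing several tools back to back legitimately yields consecutive same-role messages, which the old code rejected with a `ValueError`. A tiny repro of what used to trip it (role strings are illustrative, not the real `Message` objects):

```python
# A two-tool turn produces consecutive same-role entries.
roles = ["user", "assistant", "tool", "tool"]
for i in range(1, len(roles)):
    if roles[i] == roles[i - 1]:
        print(f"old check would raise ValueError at message {i}")
```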
Changes to `tests/integration/agents/test_agents.py`:

```diff
@@ -584,7 +584,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
 )
 def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
-    client_tool, expectes_metadata = client_tools
+    client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,
         "input_shields": [],
```
```diff
@@ -610,7 +610,7 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
     assert steps[0].step_type == "inference"
     assert steps[1].step_type == "tool_execution"
     assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
-    if expectes_metadata:
+    if expects_metadata:
         assert steps[1].tool_responses[0].metadata["source"] == "https://www.google.com"
     assert steps[2].step_type == "inference"

```
```diff
@@ -622,3 +622,44 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
         assert last_step_completed_at < step.started_at
         assert step.started_at < step.completed_at
         last_step_completed_at = step.completed_at
+
+
+def test_multi_tool_calls(llama_stack_client_with_mocked_inference, agent_config):
+    if "gpt" not in agent_config["model"]:
+        pytest.xfail("Only tested on GPT models")
+
+    agent_config = {
+        **agent_config,
+        "tools": [get_boiling_point],
+    }
+
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
+            },
+        ],
+        session_id=session_id,
+        stream=False,
+    )
+    steps = response.steps
+    assert len(steps) == 7
+    assert steps[0].step_type == "shield_call"
+    assert steps[1].step_type == "inference"
+    assert steps[2].step_type == "shield_call"
+    assert steps[3].step_type == "tool_execution"
+    assert steps[4].step_type == "shield_call"
+    assert steps[5].step_type == "inference"
+    assert steps[6].step_type == "shield_call"
+
+    tool_execution_step = steps[3]
+    assert len(tool_execution_step.tool_calls) == 2
+    assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
+    assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
+
+    output = response.output_message.content.lower()
+    assert "-100" in output and "-212" in output
```
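For reference, `get_boiling_point` is a client-tool fixture defined elsewhere in the test suite; the assertions above pin its contract. A sketch consistent with them (the exact signature and parameter names are assumptions):

```python
def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
    """Sketch of the fixture: polyjuice 'boils' at -100 C / -212 F."""
    if liquid_name.lower() == "polyjuice":
        return -100 if celsius else -212
    return -1


assert get_boiling_point("polyjuice") == -100
assert get_boiling_point("polyjuice", celsius=False) == -212
```

Note the single `tool_execution` step carrying both calls: that is exactly the behavior this PR introduces.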