diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index aa27f421c..88b6e9697 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -614,118 +614,133 @@ class ChatAgent(ShieldRunnerMixin):
                     logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
                     input_messages = input_messages + [message]
                 else:
-                    logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
-                    # 1. Start the tool execution step and progress
-                    step_id = str(uuid.uuid4())
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepStartPayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                            )
-                        )
-                    )
-                    tool_call = message.tool_calls[0]
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                                tool_call=tool_call,
-                                delta=ToolCallDelta(
-                                    parse_status=ToolCallParseStatus.in_progress,
-                                    tool_call=tool_call,
-                                ),
-                            )
-                        )
-                    )
+                    input_messages = input_messages + [message]
 
-                    # If tool is a client tool, yield CompletionMessage and return
-                    if tool_call.tool_name in client_tools:
-                        # NOTE: mark end_of_message to indicate to client that it may
-                        # call the tool and continue the conversation with the tool's response.
-                        message.stop_reason = StopReason.end_of_message
+                    # Process tool calls in the message
+                    client_tool_calls = []
+                    non_client_tool_calls = []
+
+                    # Separate client and non-client tool calls
+                    for tool_call in message.tool_calls:
+                        if tool_call.tool_name in client_tools:
+                            client_tool_calls.append(tool_call)
+                        else:
+                            non_client_tool_calls.append(tool_call)
+
+                    # Process non-client tool calls first
+                    for tool_call in non_client_tool_calls:
+                        step_id = str(uuid.uuid4())
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepStartPayload(
+                                    step_type=StepType.tool_execution.value,
+                                    step_id=step_id,
+                                )
+                            )
+                        )
+
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.tool_execution.value,
+                                    step_id=step_id,
+                                    delta=ToolCallDelta(
+                                        parse_status=ToolCallParseStatus.in_progress,
+                                        tool_call=tool_call,
+                                    ),
+                                )
+                            )
+                        )
+
+                        # Execute the tool call
+                        async with tracing.span(
+                            "tool_execution",
+                            {
+                                "tool_name": tool_call.tool_name,
+                                "input": message.model_dump_json(),
+                            },
+                        ) as span:
+                            tool_execution_start_time = datetime.now(timezone.utc).isoformat()
+                            tool_result = await self.execute_tool_call_maybe(
+                                session_id,
+                                tool_call,
+                            )
+                            if tool_result.content is None:
+                                raise ValueError(
+                                    f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
+                                )
+                            result_message = ToolResponseMessage(
+                                call_id=tool_call.call_id,
+                                content=tool_result.content,
+                            )
+                            span.set_attribute("output", result_message.model_dump_json())
+
+                        # Store tool execution step
+                        tool_execution_step = ToolExecutionStep(
+                            step_id=step_id,
+                            turn_id=turn_id,
+                            tool_calls=[tool_call],
+                            tool_responses=[
+                                ToolResponse(
+                                    call_id=tool_call.call_id,
+                                    tool_name=tool_call.tool_name,
+                                    content=tool_result.content,
+                                    metadata=tool_result.metadata,
+                                )
+                            ],
+                            started_at=tool_execution_start_time,
+                            completed_at=datetime.now(timezone.utc).isoformat(),
+                        )
+
+                        # Yield the step completion event
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepCompletePayload(
+                                    step_type=StepType.tool_execution.value,
+                                    step_id=step_id,
+                                    step_details=tool_execution_step,
+                                )
+                            )
+                        )
+
+                        # Add the result message to input_messages for the next iteration
+                        input_messages.append(result_message)
+
+                        # TODO: add tool-input touchpoint and a "start" event for this step also
+                        # but that needs a lot more refactoring of Tool code potentially
+                        if (type(result_message.content) is str) and (
+                            out_attachment := _interpret_content_as_attachment(result_message.content)
+                        ):
+                            # NOTE: when we push this message back to the model, the model may ignore the
+                            # attached file path etc. since the model is trained to only provide a user message
+                            # with the summary. We keep all generated attachments and then attach them to final message
+                            output_attachments.append(out_attachment)
+
+                    # If there are client tool calls, yield a message with only those tool calls
+                    if client_tool_calls:
                         await self.storage.set_in_progress_tool_call_step(
                             session_id,
                             turn_id,
                             ToolExecutionStep(
                                 step_id=step_id,
                                 turn_id=turn_id,
-                                tool_calls=[tool_call],
+                                tool_calls=client_tool_calls,
                                 tool_responses=[],
                                 started_at=datetime.now(timezone.utc).isoformat(),
                             ),
                         )
-                        yield message
+
+                        # Create a copy of the message with only client tool calls
+                        client_message = message.model_copy(deep=True)
+                        client_message.tool_calls = client_tool_calls
+                        # NOTE: mark end_of_message to indicate to client that it may
+                        # call the tool and continue the conversation with the tool's response.
+                        client_message.stop_reason = StopReason.end_of_message
+
+                        # Yield the message with client tool calls
+                        yield client_message
                         return
 
-                    # If tool is a builtin server tool, execute it
-                    tool_name = tool_call.tool_name
-                    if isinstance(tool_name, BuiltinTool):
-                        tool_name = tool_name.value
-                    async with tracing.span(
-                        "tool_execution",
-                        {
-                            "tool_name": tool_name,
-                            "input": message.model_dump_json(),
-                        },
-                    ) as span:
-                        tool_execution_start_time = datetime.now(timezone.utc).isoformat()
-                        tool_call = message.tool_calls[0]
-                        tool_result = await self.execute_tool_call_maybe(
-                            session_id,
-                            tool_call,
-                        )
-                        if tool_result.content is None:
-                            raise ValueError(
-                                f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
-                            )
-                        result_messages = [
-                            ToolResponseMessage(
-                                call_id=tool_call.call_id,
-                                content=tool_result.content,
-                            )
-                        ]
-                        assert len(result_messages) == 1, "Currently not supporting multiple messages"
-                        result_message = result_messages[0]
-                        span.set_attribute("output", result_message.model_dump_json())
-
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                                step_details=ToolExecutionStep(
-                                    step_id=step_id,
-                                    turn_id=turn_id,
-                                    tool_calls=[tool_call],
-                                    tool_responses=[
-                                        ToolResponse(
-                                            call_id=result_message.call_id,
-                                            tool_name=tool_call.tool_name,
-                                            content=result_message.content,
-                                            metadata=tool_result.metadata,
-                                        )
-                                    ],
-                                    started_at=tool_execution_start_time,
-                                    completed_at=datetime.now(timezone.utc).isoformat(),
-                                ),
-                            )
-                        )
-                    )
-
-                    # TODO: add tool-input touchpoint and a "start" event for this step also
-                    # but that needs a lot more refactoring of Tool code potentially
-                    if (type(result_message.content) is str) and (
-                        out_attachment := _interpret_content_as_attachment(result_message.content)
-                    ):
-                        # NOTE: when we push this message back to the model, the model may ignore the
-                        # attached file path etc. since the model is trained to only provide a user message
-                        # with the summary. We keep all generated attachments and then attach them to final message
-                        output_attachments.append(out_attachment)
-
-                    input_messages = input_messages + [message, result_message]
-
     async def _initialize_tools(
         self,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index af0987fa8..e514e3781 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -227,13 +227,6 @@ class LlamaGuardShield:
         if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
             messages = messages[1:]
 
-        for i in range(1, len(messages)):
-            if messages[i].role == messages[i - 1].role:
-                for i, m in enumerate(messages):
-                    print(f"{i}: {m.role}: {m.content}")
-                raise ValueError(
-                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
-                )
         return messages
 
     async def run(self, messages: List[Message]) -> RunShieldResponse:
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index ef0e8c05e..581cc9f45 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -584,7 +584,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
 )
 def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
-    client_tool, expectes_metadata = client_tools
+    client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,
         "input_shields": [],
@@ -610,7 +610,7 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
     assert steps[0].step_type == "inference"
     assert steps[1].step_type == "tool_execution"
     assert steps[1].tool_calls[0].tool_name.startswith("get_boiling_point")
-    if expectes_metadata:
+    if expects_metadata:
         assert steps[1].tool_responses[0].metadata["source"] == "https://www.google.com"
     assert steps[2].step_type == "inference"
 
@@ -622,3 +622,44 @@
         assert last_step_completed_at < step.started_at
         assert step.started_at < step.completed_at
         last_step_completed_at = step.completed_at
+
+
+def test_multi_tool_calls(llama_stack_client_with_mocked_inference, agent_config):
+    if "gpt" not in agent_config["model"]:
+        pytest.xfail("Only tested on GPT models")
+
+    agent_config = {
+        **agent_config,
+        "tools": [get_boiling_point],
+    }
+
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
+            },
+        ],
+        session_id=session_id,
+        stream=False,
+    )
+    steps = response.steps
+    assert len(steps) == 7
+    assert steps[0].step_type == "shield_call"
+    assert steps[1].step_type == "inference"
+    assert steps[2].step_type == "shield_call"
+    assert steps[3].step_type == "tool_execution"
+    assert steps[4].step_type == "shield_call"
+    assert steps[5].step_type == "inference"
+    assert steps[6].step_type == "shield_call"
+
+    tool_execution_step = steps[3]
+    assert len(tool_execution_step.tool_calls) == 2
+    assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
+    assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
+
+    output = response.output_message.content.lower()
+    assert "-100" in output and "-212" in output