diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index d4f47c837..66c9ab829 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -698,8 +698,8 @@ def test_multi_tool_calls(llama_stack_client, agent_config): ) steps = response.steps - has_input_shield = True if agent_config.get("input_shields", None) else False - has_output_shield = True if agent_config.get("output_shields", None) else False + has_input_shield = agent_config.get("input_shields") + has_output_shield = agent_config.get("output_shields") assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0) if has_input_shield: assert steps[0].step_type == "shield_call"