feat: unify max_infer_iters in client/server agent loop (#1309)

# What does this PR do?

We currently use `max_infer_iters` in 2 different ways
1/ Server: track number of times 
2/ Client side: track number of times we send `resume_turn` request

This PR gets rid of the need of (2) and makes server track total number
of times we perform inference within a Turn

**NOTE**
The PR will assume StopReason is set to
- end_of_message: turn is not finished, we could be waiting for client
tool call responses
- end_of_turn: if the entire turn is finished and there's no more things
to be done.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py::test_custom_tool_infinite_loop --inference-model "meta-llama/Llama-3.3-70B-Instruct"
```

[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-03 10:08:36 -08:00 committed by GitHub
parent 754feba61f
commit 7d111c7510
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 50 additions and 3 deletions

View file

@ -540,7 +540,8 @@ class ChatAgent(ShieldRunnerMixin):
output_attachments = []
n_iter = 0
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
@ -627,6 +628,9 @@ class ChatAgent(ShieldRunnerMixin):
)
span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
stop_reason = stop_reason or StopReason.out_of_tokens
# If tool calls are parsed successfully,
@ -662,6 +666,9 @@ class ChatAgent(ShieldRunnerMixin):
if n_iter >= self.agent_config.max_infer_iters:
log.info("Done with MAX iterations, exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn
yield message
break
@ -711,6 +718,9 @@ class ChatAgent(ShieldRunnerMixin):
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
message.stop_reason = StopReason.end_of_message
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
@ -796,8 +806,6 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message]
n_iter += 1
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]: