feat: unify max_infer_iters in client/server agent loop (#1309)

# What does this PR do?

We currently use `max_infer_iters` in 2 different ways
1/ Server: track number of times 
2/ Client side: track number of times we send `resume_turn` request

This PR gets rid of the need of (2) and makes server track total number
of times we perform inference within a Turn

**NOTE**
The PR will assume StopReason is set to
- end_of_message: turn is not finished, we could be waiting for client
tool call responses
- end_of_turn: if the entire turn is finished and there's no more things
to be done.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py::test_custom_tool_infinite_loop --inference-model "meta-llama/Llama-3.3-70B-Instruct"
```

[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-03 10:08:36 -08:00 committed by GitHub
parent 754feba61f
commit 7d111c7510
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 50 additions and 3 deletions

View file

@ -278,6 +278,33 @@ def test_custom_tool(llama_stack_client, agent_config):
assert "get_boiling_point" in logs_str
def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
client_tool = get_boiling_point
agent_config = {
**agent_config,
"instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
"client_tools": [client_tool.get_tool_definition()],
"max_infer_iters": 5,
}
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[
{
"role": "user",
"content": "Get the boiling point of polyjuice with a tool call.",
},
],
session_id=session_id,
stream=False,
)
num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
assert num_tool_calls <= 5
def test_tool_choice(llama_stack_client, agent_config):
def run_agent(tool_choice):
client_tool = get_boiling_point