From 7d111c7510d59010ece6eee3dc84c4c05d217d7d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 3 Mar 2025 10:08:36 -0800 Subject: [PATCH] feat: unify max_infer_iters in client/server agent loop (#1309) # What does this PR do? We currently use `max_infer_iters` in two different ways: 1/ Server side: track the number of inference iterations within a single request 2/ Client side: track the number of times we send a `resume_turn` request. This PR removes the need for (2) and makes the server track the total number of times we perform inference within a turn. **NOTE** This PR assumes StopReason is set to one of: - end_of_message: the turn is not finished; we could be waiting for client tool call responses - end_of_turn: the entire turn is finished and there is nothing more to be done. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py::test_custom_tool_infinite_loop --inference-model "meta-llama/Llama-3.3-70B-Instruct" ``` [//]: # (## Documentation) --- .../agents/meta_reference/agent_instance.py | 14 +++++++--- .../agents/meta_reference/persistence.py | 12 +++++++++ tests/client-sdk/agents/test_agents.py | 27 +++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index be7a6bf3c..34a9db630 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -540,7 +540,8 @@ class ChatAgent(ShieldRunnerMixin): output_attachments = [] - n_iter = 0 + n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 + # Build a map of custom tools to their definitions for faster lookup client_tools = {} for tool in self.agent_config.client_tools: @@ -627,6 +628,9 @@ class ChatAgent(ShieldRunnerMixin): ) 
span.set_attribute("output", output_attr) + n_iter += 1 + await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter) + stop_reason = stop_reason or StopReason.out_of_tokens # If tool calls are parsed successfully, @@ -662,6 +666,9 @@ class ChatAgent(ShieldRunnerMixin): if n_iter >= self.agent_config.max_infer_iters: log.info("Done with MAX iterations, exiting.") + # NOTE: mark end_of_turn to indicate to client that we are done with the turn + # Do not continue the tool call loop after this point + message.stop_reason = StopReason.end_of_turn yield message break @@ -711,6 +718,9 @@ class ChatAgent(ShieldRunnerMixin): # If tool is a client tool, yield CompletionMessage and return if tool_call.tool_name in client_tools: + # NOTE: mark end_of_message to indicate to client that it may + # call the tool and continue the conversation with the tool's response. + message.stop_reason = StopReason.end_of_message await self.storage.set_in_progress_tool_call_step( session_id, turn_id, @@ -796,8 +806,6 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message, result_message] - n_iter += 1 - async def _get_tool_defs( self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None ) -> Tuple[List[ToolDefinition], Dict[str, str]]: diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 9a2663567..2c04305fd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -105,3 +105,15 @@ class AgentPersistence: key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", ) return ToolExecutionStep(**json.loads(value)) if value else None + + async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int): + await self.kvstore.set( + key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", + 
value=str(num_infer_iters), + ) + + async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: + value = await self.kvstore.get( + key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", + ) + return int(value) if value else None diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 201dcc1d4..4d808da23 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -278,6 +278,33 @@ def test_custom_tool(llama_stack_client, agent_config): assert "get_boiling_point" in logs_str +def test_custom_tool_infinite_loop(llama_stack_client, agent_config): + client_tool = get_boiling_point + agent_config = { + **agent_config, + "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ", + "client_tools": [client_tool.get_tool_definition()], + "max_infer_iters": 5, + } + + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Get the boiling point of polyjuice with a tool call.", + }, + ], + session_id=session_id, + stream=False, + ) + + num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps]) + assert num_tool_calls <= 5 + + def test_tool_choice(llama_stack_client, agent_config): def run_agent(tool_choice): client_tool = get_boiling_point