From af254f8cf3dbd63c7c613b67a91bc8bdae047183 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 16:30:45 -0800 Subject: [PATCH] max infer iters --- .../inline/agents/meta_reference/agent_instance.py | 3 ++- .../inline/agents/meta_reference/persistence.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 3502c21f2..79919adcd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -535,7 +535,7 @@ class ChatAgent(ShieldRunnerMixin): output_attachments = [] - n_iter = 0 + n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 # Build a map of custom tools to their definitions for faster lookup client_tools = {} for tool in self.agent_config.client_tools: @@ -783,6 +783,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message, result_message] n_iter += 1 + await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter) async def _get_tool_defs( self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 9a2663567..2c04305fd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -105,3 +105,15 @@ class AgentPersistence: key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", ) return ToolExecutionStep(**json.loads(value)) if value else None + + async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int): + await self.kvstore.set( + key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", + value=str(num_infer_iters), + ) + + async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: + value = await self.kvstore.get( + key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", + ) + return int(value) if value else None