This commit is contained in:
Xi Yan 2025-02-27 16:47:08 -08:00
parent af254f8cf3
commit c5b6e3845f
2 changed files with 9 additions and 3 deletions

View file

@ -536,6 +536,8 @@ class ChatAgent(ShieldRunnerMixin):
output_attachments = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
print("initial n_iter", n_iter)
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
@ -613,6 +615,11 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
n_iter += 1
print(f"n_iter after update: {n_iter}")
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
print(self.agent_config.max_infer_iters)
stop_reason = stop_reason or StopReason.out_of_tokens
# If tool calls are parsed successfully,
@ -782,9 +789,6 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message]
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]:

View file

@ -107,6 +107,7 @@ class AgentPersistence:
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
print(f"setting num_infer_iters_in_turn: {num_infer_iters}")
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
@ -116,4 +117,5 @@ class AgentPersistence:
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
print(f"getting num_infer_iters_in_turn: {value}")
return int(value) if value else None