mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
max infer iters
This commit is contained in:
parent
663c6b0537
commit
af254f8cf3
2 changed files with 14 additions and 1 deletions
|
@ -535,7 +535,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
client_tools = {}
|
client_tools = {}
|
||||||
for tool in self.agent_config.client_tools:
|
for tool in self.agent_config.client_tools:
|
||||||
|
@ -783,6 +783,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages = input_messages + [message, result_message]
|
input_messages = input_messages + [message, result_message]
|
||||||
|
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
|
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||||
|
|
||||||
async def _get_tool_defs(
|
async def _get_tool_defs(
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
|
|
|
@ -105,3 +105,15 @@ class AgentPersistence:
|
||||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
)
|
)
|
||||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
||||||
|
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||||
|
await self.kvstore.set(
|
||||||
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
value=str(num_infer_iters),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||||
|
value = await self.kvstore.get(
|
||||||
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
)
|
||||||
|
return int(value) if value else None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue