forked from phoenix-oss/llama-stack-mirror
feat: unify max_infer_iters in client/server agent loop (#1309)
# What does this PR do? We currently use `max_infer_iters` in 2 different ways 1/ Server: track number of times 2/ Client side: track number of times we send `resume_turn` request This PR gets rid of the need of (2) and makes server track total number of times we perform inference within a Turn **NOTE** The PR will assume StopReason is set to - end_of_message: turn is not finished, we could be waiting for client tool call responses - end_of_turn: if the entire turn is finished and there's no more things to be done. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py::test_custom_tool_infinite_loop --inference-model "meta-llama/Llama-3.3-70B-Instruct" ``` [//]: # (## Documentation)
This commit is contained in:
parent
754feba61f
commit
7d111c7510
3 changed files with 50 additions and 3 deletions
|
@ -105,3 +105,15 @@ class AgentPersistence:
|
|||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue