mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
feat: unify max_infer_iters in client/server agent loop (#1309)
# What does this PR do? We currently use `max_infer_iters` in 2 different ways: 1/ Server: track the number of times we perform inference within a turn; 2/ Client side: track the number of times we send a `resume_turn` request. This PR gets rid of the need for (2) and makes the server track the total number of times we perform inference within a turn. **NOTE** The PR assumes StopReason is set to - end_of_message: the turn is not finished; we could be waiting for client tool call responses - end_of_turn: the entire turn is finished and there is nothing more to be done. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py::test_custom_tool_infinite_loop --inference-model "meta-llama/Llama-3.3-70B-Instruct" ``` [//]: # (## Documentation)
This commit is contained in:
parent
754feba61f
commit
7d111c7510
3 changed files with 50 additions and 3 deletions
|
@@ -540,7 +540,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||||
|
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
client_tools = {}
|
client_tools = {}
|
||||||
for tool in self.agent_config.client_tools:
|
for tool in self.agent_config.client_tools:
|
||||||
|
@@ -627,6 +628,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
span.set_attribute("output", output_attr)
|
span.set_attribute("output", output_attr)
|
||||||
|
|
||||||
|
n_iter += 1
|
||||||
|
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||||
|
|
||||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||||
|
|
||||||
# If tool calls are parsed successfully,
|
# If tool calls are parsed successfully,
|
||||||
|
@@ -662,6 +666,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
if n_iter >= self.agent_config.max_infer_iters:
|
if n_iter >= self.agent_config.max_infer_iters:
|
||||||
log.info("Done with MAX iterations, exiting.")
|
log.info("Done with MAX iterations, exiting.")
|
||||||
|
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||||
|
# Do not continue the tool call loop after this point
|
||||||
|
message.stop_reason = StopReason.end_of_turn
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@@ -711,6 +718,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
# If tool is a client tool, yield CompletionMessage and return
|
# If tool is a client tool, yield CompletionMessage and return
|
||||||
if tool_call.tool_name in client_tools:
|
if tool_call.tool_name in client_tools:
|
||||||
|
# NOTE: mark end_of_message to indicate to client that it may
|
||||||
|
# call the tool and continue the conversation with the tool's response.
|
||||||
|
message.stop_reason = StopReason.end_of_message
|
||||||
await self.storage.set_in_progress_tool_call_step(
|
await self.storage.set_in_progress_tool_call_step(
|
||||||
session_id,
|
session_id,
|
||||||
turn_id,
|
turn_id,
|
||||||
|
@@ -796,8 +806,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
input_messages = input_messages + [message, result_message]
|
input_messages = input_messages + [message, result_message]
|
||||||
|
|
||||||
n_iter += 1
|
|
||||||
|
|
||||||
async def _get_tool_defs(
|
async def _get_tool_defs(
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||||
|
|
|
@@ -105,3 +105,15 @@ class AgentPersistence:
|
||||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
)
|
)
|
||||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
||||||
|
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
    """Persist how many inference iterations have run so far within a turn.

    Stored in the kvstore under a key scoped to (agent, session, turn) so the
    count survives across requests (e.g. a resumed turn).
    """
    # Key shape must match get_num_infer_iters_in_turn; value is stored as a string.
    key = f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}"
    await self.kvstore.set(key=key, value=str(num_infer_iters))
|
||||||
|
|
||||||
|
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
    """Return the persisted inference-iteration count for a turn, or None if unset.

    Counterpart of set_num_infer_iters_in_turn; reads the same
    (agent, session, turn)-scoped key and parses the stored string back to int.
    """
    raw = await self.kvstore.get(
        key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
    )
    # Treat a missing (or empty) entry as "no iterations recorded yet".
    if not raw:
        return None
    return int(raw)
|
||||||
|
|
|
@@ -278,6 +278,33 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
assert "get_boiling_point" in logs_str
|
assert "get_boiling_point" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
    """Verify the server-side max_infer_iters cap stops an agent that always tool-calls.

    The agent is instructed to respond with tool calls no matter what, which
    would loop indefinitely without the cap. After the turn completes, the
    number of tool-execution steps must not exceed the configured limit.
    """
    client_tool = get_boiling_point
    agent_config = {
        **agent_config,
        "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
        "client_tools": [client_tool.get_tool_definition()],
        "max_infer_iters": 5,
    }

    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
    session_id = agent.create_session(f"test-session-{uuid4()}")

    response = agent.create_turn(
        messages=[
            {
                "role": "user",
                "content": "Get the boiling point of polyjuice with a tool call.",
            },
        ],
        session_id=session_id,
        stream=False,
    )

    # Count tool-execution steps with a generator expression (no throwaway
    # list of 0/1s) and compare against the configured cap instead of
    # repeating the magic number 5, so the assert tracks the config.
    num_tool_calls = sum(1 for step in response.steps if step.step_type == "tool_execution")
    assert num_tool_calls <= agent_config["max_infer_iters"]
|
||||||
|
|
||||||
|
|
||||||
def test_tool_choice(llama_stack_client, agent_config):
|
def test_tool_choice(llama_stack_client, agent_config):
|
||||||
def run_agent(tool_choice):
|
def run_agent(tool_choice):
|
||||||
client_tool = get_boiling_point
|
client_tool = get_boiling_point
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue