fix: don't include tool args not in the function definition (#1307)

# Summary:
Right now we would include toolgroup args when we encode messages with
tool_calls, which is confusing the model since they not in the function
description (see test plan for example).

# Test Plan:
Add a print statement before raw prompt is sent to providers (no good
way to test this currently)

Before:
```
cated in the same neighborhood?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n[knowledge_search(query="Laleli Mosque and Esma Sultan Mansion same neighborhood", vector_db_ids=["829a68735d744dc3830409dcc782964a"])]<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n\nknowledge_search tool found 5 chunks:\nBEGIN of
```
Note the extra `vector_db_ids`

After
```
>user<|end_header_id|>\n\nAre the Laleli Mosque and Esma Sultan Mansion located in the same neighborhood?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n[knowledge_search(query="Laleli Mosque and Esma Sultan Mansion same neighborhood")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n\nknowledge_search tool found
```
This commit is contained in:
ehhuang 2025-02-27 16:25:30 -08:00 committed by GitHub
parent 663c6b0537
commit a34f3aafcf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1054,9 +1054,6 @@ async def execute_tool_call_maybe(
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL
@ -1065,10 +1062,12 @@ async def execute_tool_call_maybe(
result = await tool_runtime_api.invoke_tool(
tool_name=name,
kwargs=dict(
session_id=session_id,
**tool_call_args,
),
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**toolgroup_args.get(group_name, {}),
},
)
return result