forked from phoenix-oss/llama-stack-mirror
fix: don't include tool args not in the function definition (#1307)
# Summary: Right now we would include toolgroup args when we encode messages with tool_calls, which is confusing the model since they not in the function description (see test plan for example). # Test Plan: Add a print statement before raw prompt is sent to providers (no good way to test this currently) Before: ``` cated in the same neighborhood?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n[knowledge_search(query="Laleli Mosque and Esma Sultan Mansion same neighborhood", vector_db_ids=["829a68735d744dc3830409dcc782964a"])]<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n\nknowledge_search tool found 5 chunks:\nBEGIN of ``` Note the extra `vector_db_ids` After ``` >user<|end_header_id|>\n\nAre the Laleli Mosque and Esma Sultan Mansion located in the same neighborhood?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n[knowledge_search(query="Laleli Mosque and Esma Sultan Mansion same neighborhood")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n\nknowledge_search tool found ```
This commit is contained in:
parent
663c6b0537
commit
a34f3aafcf
1 changed files with 6 additions and 7 deletions
|
@ -1054,9 +1054,6 @@ async def execute_tool_call_maybe(
|
|||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
tool_call_args = tool_call.arguments
|
||||
tool_call_args.update(toolgroup_args.get(group_name, {}))
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = WEB_SEARCH_TOOL
|
||||
|
@ -1065,10 +1062,12 @@ async def execute_tool_call_maybe(
|
|||
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs=dict(
|
||||
session_id=session_id,
|
||||
**tool_call_args,
|
||||
),
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue