This commit is contained in:
Ashwin Bharambe 2025-10-01 16:01:09 -07:00
parent 6749c853c0
commit ae14966204
6 changed files with 3221 additions and 54 deletions

View file

@ -220,17 +220,18 @@ class ToolUtils:
@staticmethod
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
args = json.loads(t.arguments)
if t.tool_name == BuiltinTool.brave_search:
q = t.arguments["query"]
q = args["query"]
return f'brave_search.call(query="{q}")'
elif t.tool_name == BuiltinTool.wolfram_alpha:
q = t.arguments["query"]
q = args["query"]
return f'wolfram_alpha.call(query="{q}")'
elif t.tool_name == BuiltinTool.photogen:
q = t.arguments["query"]
q = args["query"]
return f'photogen.call(query="{q}")'
elif t.tool_name == BuiltinTool.code_interpreter:
return t.arguments["code"]
return args["code"]
else:
fname = t.tool_name
@ -239,12 +240,11 @@ class ToolUtils:
{
"type": "function",
"name": fname,
"parameters": t.arguments,
"parameters": args,
}
)
elif tool_prompt_format == ToolPromptFormat.function_tag:
args = json.dumps(t.arguments)
return f"<function={fname}>{args}</function>"
return f"<function={fname}>{t.arguments}</function>"
elif tool_prompt_format == ToolPromptFormat.python_list:
@ -260,7 +260,7 @@ class ToolUtils:
else:
raise ValueError(f"Unsupported type: {type(value)}")
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in args.items())
return f"[{fname}({args_str})]"
else:
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")

View file

@ -875,12 +875,18 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
try:
args = json.loads(tool_call.arguments)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**args,
**self.tool_name_to_args.get(tool_name_str, {}),
},
)