chore(api): add mypy coverage to tools_utils

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 20:11:25 +02:00
parent 51b179e1c5
commit 5c1671ac8c
2 changed files with 77 additions and 17 deletions

View file

@ -167,7 +167,7 @@ def parse_llama_tool_call_format(input_string):
class ToolUtils:
@staticmethod
def is_builtin_tool_call(message_body: str) -> bool:
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
return match is not None
@staticmethod
@ -184,7 +184,7 @@ class ToolUtils:
return None
@staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, dict] | None:
# NOTE: Custom function too calls are still experimental
# Sometimes, response is of the form
# {"type": "function", "name": "function_name", "parameters": {...}
@ -197,7 +197,11 @@ class ToolUtils:
tool_name = match.group("function_name")
query = match.group("args")
try:
return tool_name, json.loads(query.replace("'", '"'))
parsed_args = json.loads(query.replace("'", '"'))
if isinstance(parsed_args, dict):
return tool_name, parsed_args
else:
return None
except Exception as e:
print("Exception while parsing json query for custom tool call", query, e)
return None
@ -208,12 +212,22 @@ class ToolUtils:
):
function_name = response["name"]
args = response["parameters"]
if isinstance(args, dict):
return function_name, args
else:
return None
else:
return None
elif function_calls := parse_llama_tool_call_format(message_body):
# FIXME: Enable multiple tool calls
return function_calls[0]
if function_calls and len(function_calls) > 0:
func_name, args_dict = function_calls[0]
if isinstance(args_dict, dict):
return func_name, args_dict
else:
return None
else:
return None
else:
logger.debug(f"Did not parse tool call from message body: {message_body}")
return None
@ -221,28 +235,63 @@ class ToolUtils:
@staticmethod
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
if t.tool_name == BuiltinTool.brave_search:
if isinstance(t.arguments, dict):
q = t.arguments["query"]
return f'brave_search.call(query="{q}")'
else:
# Handle string arguments
return f'brave_search.call(query="{t.arguments}")'
elif t.tool_name == BuiltinTool.wolfram_alpha:
if isinstance(t.arguments, dict):
q = t.arguments["query"]
return f'wolfram_alpha.call(query="{q}")'
else:
# Handle string arguments
return f'wolfram_alpha.call(query="{t.arguments}")'
elif t.tool_name == BuiltinTool.photogen:
if isinstance(t.arguments, dict):
q = t.arguments["query"]
return f'photogen.call(query="{q}")'
else:
# Handle string arguments
return f'photogen.call(query="{t.arguments}")'
elif t.tool_name == BuiltinTool.code_interpreter:
return t.arguments["code"]
if isinstance(t.arguments, dict):
code = t.arguments["code"]
return str(code)
else:
# Handle string arguments
return str(t.arguments)
else:
fname = t.tool_name
if tool_prompt_format == ToolPromptFormat.json:
# For JSON format, we need to handle both string and dict arguments
if isinstance(t.arguments, str):
# Try to parse string as JSON
try:
parsed_args = json.loads(t.arguments)
except json.JSONDecodeError:
parsed_args = {"value": t.arguments}
else:
parsed_args = t.arguments
return json.dumps(
{
"type": "function",
"name": fname,
"parameters": t.arguments,
"parameters": parsed_args,
}
)
elif tool_prompt_format == ToolPromptFormat.function_tag:
if isinstance(t.arguments, str):
# Try to parse string as JSON
try:
parsed_args = json.loads(t.arguments)
args = json.dumps(parsed_args)
except json.JSONDecodeError:
args = json.dumps({"value": t.arguments})
else:
args = json.dumps(t.arguments)
return f"<function={fname}>{args}</function>"
@ -256,10 +305,21 @@ class ToolUtils:
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"
elif isinstance(value, dict):
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
return f"{{{', '.join(f'"{k}"={format_value(v)}' for k, v in value.items())}}}"
else:
raise ValueError(f"Unsupported type: {type(value)}")
if isinstance(t.arguments, str):
# Try to parse string as JSON
try:
parsed_args = json.loads(t.arguments)
if isinstance(parsed_args, dict):
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in parsed_args.items())
else:
args_str = f"value={format_value(parsed_args)}"
except json.JSONDecodeError:
args_str = f'value="{t.arguments}"'
else:
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
return f"[{fname}({args_str})]"
else:

View file

@ -245,7 +245,7 @@ exclude = [
"^llama_stack/models/llama/llama3/chat_format\\.py$",
"^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",