feat: support '-' in tool names (#1807)

# What does this PR do?
titled

## Test Plan
added new unit tests
pytest -s -v tests/unit/models/llama/llama3/test_tool_utils.py
This commit is contained in:
ehhuang 2025-04-12 14:23:03 -07:00 committed by GitHub
parent ef3dc143ec
commit ad86a68a32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 275 additions and 76 deletions

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple
@ -35,80 +28,141 @@ def is_json(s):
return True
def is_valid_python_list(input_string):
"""Check if the input string is a valid Python list of function calls"""
try:
# Try to parse the string
tree = ast.parse(input_string)
# Check if it's a single expression
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
return False
# Check if the expression is a list
expr = tree.body[0].value
if not isinstance(expr, ast.List):
return False
# Check if the list is empty
if len(expr.elts) == 0:
return False
# Check if all elements in the list are function calls
for element in expr.elts:
if not isinstance(element, ast.Call):
return False
# Check if the function call has a valid name
if not isinstance(element.func, ast.Name):
return False
# Check if all arguments are keyword arguments
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
return False
return True
except SyntaxError:
# If parsing fails, it's not a valid Python expression
return False
def parse_python_list_for_function_calls(input_string):
def parse_llama_tool_call_format(input_string):
"""
Parse a Python list of function calls and
return a list of tuples containing the function name and arguments
"""
# Parse the string into an AST
tree = ast.parse(input_string)
Parse tool calls in the format:
[func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
# Ensure the input is a list
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
raise ValueError("Input must be a list of function calls")
Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
"""
# Strip outer brackets and whitespace
input_string = input_string.strip()
if not (input_string.startswith("[") and input_string.endswith("]")):
return None
content = input_string[1:-1].strip()
if not content:
return None
result = []
# Iterate through each function call in the list
for node in tree.body[0].value.elts:
if isinstance(node, ast.Call):
function_name = node.func.id
function_args = {}
# State variables for parsing
pos = 0
length = len(content)
# Extract keyword arguments
for keyword in node.keywords:
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
while pos < length:
# Find function name
name_end = content.find("(", pos)
if name_end == -1:
break
result.append((function_name, function_args))
func_name = content[pos:name_end].strip()
return result
# Find closing parenthesis for this function call
paren_level = 1
args_start = name_end + 1
args_end = args_start
while args_end < length and paren_level > 0:
if content[args_end] == "(":
paren_level += 1
elif content[args_end] == ")":
paren_level -= 1
args_end += 1
if paren_level != 0:
# Unmatched parentheses
return None
# Parse arguments
args_str = content[args_start : args_end - 1].strip()
args_dict = {}
if args_str:
# Split by commas, but respect nested structures
parts = []
part_start = 0
in_quotes = False
quote_char = None
nested_level = 0
for i, char in enumerate(args_str):
if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
if not in_quotes:
in_quotes = True
quote_char = char
elif char == quote_char:
in_quotes = False
quote_char = None
elif not in_quotes:
if char in ("{", "["):
nested_level += 1
elif char in ("}", "]"):
nested_level -= 1
elif char == "," and nested_level == 0:
parts.append(args_str[part_start:i].strip())
part_start = i + 1
parts.append(args_str[part_start:].strip())
# Process each key=value pair
for part in parts:
if "=" in part:
key, value = part.split("=", 1)
key = key.strip()
value = value.strip()
# Try to convert value to appropriate Python type
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
# String
value = value[1:-1]
elif value.lower() == "true":
value = True
elif value.lower() == "false":
value = False
elif value.lower() == "none":
value = None
elif value.startswith("{") and value.endswith("}"):
# This is a nested dictionary
try:
# Try to parse as JSON
value = json.loads(value.replace("'", '"'))
except json.JSONDecodeError:
# Keep as string if parsing fails
pass
elif value.startswith("[") and value.endswith("]"):
# This is a nested list
try:
# Try to parse as JSON
value = json.loads(value.replace("'", '"'))
except json.JSONDecodeError:
# Keep as string if parsing fails
pass
else:
# Try to convert to number
try:
if "." in value:
value = float(value)
else:
value = int(value)
except ValueError:
# Keep as string if not a valid number
pass
args_dict[key] = value
result.append((func_name, args_dict))
# Move to the next function call
pos = args_end
# Skip the comma between function calls if present
if pos < length and content[pos] == ",":
pos += 1
return result if result else None
class ToolUtils:
@ -156,11 +210,11 @@ class ToolUtils:
return function_name, args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
elif function_calls := parse_llama_tool_call_format(message_body):
# FIXME: Enable multiple tool calls
return res[0]
return function_calls[0]
else:
logger.debug(f"Did not parse tool call from message body: {message_body}")
return None
@staticmethod

View file

@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
class TestMaybeExtractCustomToolCall:
def test_valid_single_tool_call(self):
input_string = '[get_weather(location="San Francisco", units="celsius")]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "get_weather"
assert result[1] == {"location": "San Francisco", "units": "celsius"}
def test_valid_multiple_tool_calls(self):
input_string = '[search(query="python programming"), get_time(timezone="UTC")]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
# Note: maybe_extract_custom_tool_call currently only returns the first tool call
assert result is not None
assert len(result) == 2
assert result[0] == "search"
assert result[1] == {"query": "python programming"}
def test_different_value_types(self):
input_string = '[analyze_data(count=42, enabled=True, ratio=3.14, name="test", options=None)]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "analyze_data"
assert result[1] == {"count": 42, "enabled": True, "ratio": 3.14, "name": "test", "options": None}
def test_nested_structures(self):
input_string = '[complex_function(filters={"min": 10, "max": 100}, tags=["important", "urgent"])]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
# This test checks that nested structures are handled
assert result is not None
assert len(result) == 2
assert result[0] == "complex_function"
assert "filters" in result[1]
assert sorted(result[1]["filters"].items()) == sorted({"min": 10, "max": 100}.items())
assert "tags" in result[1]
assert result[1]["tags"] == ["important", "urgent"]
def test_hyphenated_function_name(self):
input_string = '[weather-forecast(city="London")]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "weather-forecast" # Function name remains hyphenated
assert result[1] == {"city": "London"}
def test_empty_input(self):
input_string = "[]"
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is None
def test_invalid_format(self):
invalid_inputs = [
'get_weather(location="San Francisco")', # Missing outer brackets
'{get_weather(location="San Francisco")}', # Wrong outer brackets
'[get_weather(location="San Francisco"]', # Unmatched brackets
'[get_weather{location="San Francisco"}]', # Wrong inner brackets
"just some text", # Not a tool call format at all
]
for input_string in invalid_inputs:
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is None
def test_quotes_handling(self):
input_string = '[search(query="Text with \\"quotes\\" inside")]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
# This test checks that escaped quotes are handled correctly
assert result is not None
def test_single_quotes_in_arguments(self):
input_string = "[add-note(name='demonote', content='demonstrating Llama Stack and MCP integration')]"
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "add-note" # Function name remains hyphenated
assert result[1] == {"name": "demonote", "content": "demonstrating Llama Stack and MCP integration"}
def test_json_format(self):
input_string = '{"type": "function", "name": "search_web", "parameters": {"query": "AI research"}}'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "search_web"
assert result[1] == {"query": "AI research"}
def test_python_list_format(self):
input_string = "[calculate(x=10, y=20)]"
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "calculate"
assert result[1] == {"x": 10, "y": 20}
def test_complex_nested_structures(self):
input_string = '[advanced_query(config={"filters": {"categories": ["books", "electronics"], "price_range": {"min": 10, "max": 500}}, "sort": {"field": "relevance", "order": "desc"}})]'
result = ToolUtils.maybe_extract_custom_tool_call(input_string)
assert result is not None
assert len(result) == 2
assert result[0] == "advanced_query"
# Verify the overall structure
assert "config" in result[1]
assert isinstance(result[1]["config"], dict)
# Verify the first level of nesting
config = result[1]["config"]
assert "filters" in config
assert "sort" in config
# Verify the second level of nesting (filters)
filters = config["filters"]
assert "categories" in filters
assert "price_range" in filters
# Verify the list within the dict
assert filters["categories"] == ["books", "electronics"]
# Verify the nested dict within another dict
assert filters["price_range"]["min"] == 10
assert filters["price_range"]["max"] == 500
# Verify the sort dictionary
assert config["sort"]["field"] == "relevance"
assert config["sort"]["order"] == "desc"